Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 13 Jul 2024 05:21:40 +0000 (07:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 13 Jul 2024 05:21:40 +0000 (07:21 +0200)
main.py

diff --git a/main.py b/main.py
index a8ceac8..9599cf3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -78,10 +78,6 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--min_to_validate", type=int, default=None)
-
-parser.add_argument("--max_to_validate", type=int, default=None)
-
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--proba_understands", type=float, default=0.99)
@@ -121,12 +117,6 @@ parser.add_argument("--sky_speed", type=int, default=3)
 
 args = parser.parse_args()
 
-if args.min_to_validate is None:
-    args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
-    args.max_to_validate = args.nb_gpts - 1
-
 if args.result_dir is None:
     args.result_dir = f"results_culture"
 
@@ -338,10 +328,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
 
 
 def one_epoch(model, quiz_machine, local_device=main_device):
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
     model.to(local_device).train()
 
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
     nb_train_samples, acc_train_loss = 0, 0.0
 
     for input in quiz_machine.batches(model, split="train"):