Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 06:54:11 +0000 (08:54 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 06:54:11 +0000 (08:54 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 11d712a..6c27599 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -46,7 +46,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=100)
+parser.add_argument("--nb_epochs", type=int, default=10000)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
@@ -56,12 +56,8 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
-parser.add_argument("--optim", type=str, default="adam")
-
 parser.add_argument("--learning_rate", type=float, default=1e-4)
 
-parser.add_argument("--learning_rate_schedule", type=str, default=None)
-
 ########################################
 
 parser.add_argument("--model", type=str, default=None)
@@ -716,43 +712,9 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
-if args.learning_rate_schedule == "cos":
-    learning_rate_schedule = {}
-    for n_epoch in range(args.nb_epochs):
-        u = n_epoch / args.nb_epochs * math.pi
-        learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
-else:
-    if args.learning_rate_schedule is not None:
-        u = {
-            int(k): float(v)
-            for k, v in [
-                tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
-            ]
-        }
-    else:
-        u = {}
-
-    learning_rate_schedule = {}
-    learning_rate = args.learning_rate
-    for n_epoch in range(args.nb_epochs):
-        if n_epoch in u:
-            learning_rate = u[n_epoch]
-        learning_rate_schedule[n_epoch] = learning_rate
-
-log_string(f"learning_rate_schedule {learning_rate_schedule}")
-
-######################################################################
-
 
 def one_epoch(model, task):
-    if args.optim == "sgd":
-        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
-    elif args.optim == "adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    elif args.optim == "adamw":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
-    else:
-        raise ValueError(f"Unknown optimizer {args.optim}.")
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.train()
 
@@ -851,7 +813,7 @@ def create_quizzes(
     task.save_image(
         new_quizzes[:96],
         args.result_dir,
-        f"world_new_{n_epoch:04d}.png",
+        f"world_new_{n_epoch:04d}_{model.id:02d}.png",
         log_string,
     )
 
index b4829d9..5d9a018 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2236,7 +2236,10 @@ class World(Task):
         )
 
         self.save_image(
-            result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger
+            result[:96],
+            result_dir,
+            f"world_result_{n_epoch:04d}_{model.id:02d}.png",
+            logger,
         )
 
         return main_test_accuracy