Update.
author François Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 19:58:14 +0000 (21:58 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 19:58:14 +0000 (21:58 +0200)
main.py
world.py

diff --git a/main.py b/main.py
index 4a1207d..61d77ed 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -32,7 +32,7 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "--task",
     type=str,
-    default="twotargets",
+    default="world",
     help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, greed",
 )
 
@@ -46,7 +46,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=100)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
@@ -60,7 +60,7 @@ parser.add_argument("--optim", type=str, default="adam")
 
 parser.add_argument("--learning_rate", type=float, default=1e-4)
 
-parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
+parser.add_argument("--learning_rate_schedule", type=str, default=None)
 
 ########################################
 
@@ -374,9 +374,8 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.resume:
-        print(f"result directory {args.result_dir} already exists")
-        exit(1)
+    print(f"result directory {args.result_dir} already exists")
+    exit(1)
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -745,12 +744,15 @@ if args.learning_rate_schedule == "cos":
         u = n_epoch / args.nb_epochs * math.pi
         learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
 else:
-    u = {
-        int(k): float(v)
-        for k, v in [
-            tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
-        ]
-    }
+    if args.learning_rate_schedule is not None:
+        u = {
+            int(k): float(v)
+            for k, v in [
+                tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+            ]
+        }
+    else:
+        u = {}
 
     learning_rate_schedule = {}
     learning_rate = args.learning_rate
@@ -890,19 +892,19 @@ def create_quizzes(
 
 ######################################################################
 
-accuracy_to_make_quizzes = 0.95
+accuracy_to_make_quizzes = 0.975
 
-for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
     for m in models:
         one_epoch(m, task, learning_rate)
         test_accuracy = run_tests(m, task, deterministic_synthesis=False)
 
-    if test_accuracy >= accuracy_to_make_quizzes:
-        other_models = models.copy()
-        other_models.remove(model)
-        create_quizzes(other_models, task)
+        if test_accuracy >= accuracy_to_make_quizzes:
+            other_models = models.copy()
+            other_models.remove(m)
+            create_quizzes(m, other_models, task)
 
     # --------------------------------------------
 
diff --git a/world.py b/world.py
index 43126d5..89833e6 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -22,7 +22,7 @@ colors = torch.tensor(
         [255, 0, 0],
         [0, 128, 0],
         [0, 0, 255],
-        [255, 255, 0],
+        [255, 200, 0],
         [192, 192, 192],
     ]
 )