Update: factor model construction into new_model(), checkpoint per-model nb_epochs, compute test accuracy without c-quizzes, raise nb_test_samples to 10000, and add a quiz-image dump to grids.py.
author François Fleuret <francois@fleuret.org>
Mon, 23 Sep 2024 06:15:48 +0000 (08:15 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 23 Sep 2024 06:15:48 +0000 (08:15 +0200)
grids.py
main.py

diff --git a/grids.py b/grids.py
index 0f7e554..4c132c3 100755
--- a/grids.py
+++ b/grids.py
@@ -146,8 +146,8 @@ class Grids(problem.Problem):
 
     # grid_gray = 192
     # thickness = 0
-    # background_gray = 255
-    # dots = True
+    # background_gray = 240
+    # dots = False
 
     named_colors = [
         ("white", [background_gray, background_gray, background_gray]),
@@ -1820,7 +1820,7 @@ class Grids(problem.Problem):
             print(t.__name__)
             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
             self.save_quizzes_as_image(
-                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
+                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
             )
 
 
@@ -1835,6 +1835,44 @@ if __name__ == "__main__":
 
     nb, nrow = 64, 4
     # nb, nrow = 8, 2
+    nb_rows = 13
+
+    c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
+    c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
+
+    grids.save_quizzes_as_image(
+        "/tmp",
+        "c_quizzes.png",
+        c_quizzes,
+        # delta=True,
+        nrow=nrow,
+        margin=10,
+        # grids=False
+        # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+    )
+
+    w_quizzes = grids.generate_w_quizzes_(
+        nrow * nb_rows,
+        tasks=[
+            grids.task_replace_color,
+            grids.task_translate,
+            grids.task_grow,
+            grids.task_frame,
+        ],
+    )
+
+    grids.save_quizzes_as_image(
+        "/tmp",
+        "w_quizzes.png",
+        w_quizzes,
+        # delta=True,
+        nrow=nrow,
+        margin=10,
+        # grids=False
+        # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+    )
+
+    exit(0)
 
     # for t in grids.all_tasks:
 
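Note: the new __main__ block above is a one-off visualization. It samples nrow * nb_rows cached c-quizzes from a local checkpoint, renders them alongside freshly generated w-quizzes, and exits before the older per-task loop runs. Below is a minimal sketch of the c-quiz half of that flow, assuming save_quizzes_as_image keeps the signature used above; the checkpoint path is the author's machine-specific file and is only a placeholder:

    # Sketch: render a random sample of cached c-quizzes, guarding
    # against the machine-specific checkpoint being absent.
    # grids, nrow and nb_rows as defined in the __main__ block above.
    import os
    import torch

    state_path = "/home/fleuret/state.pth"  # author's local file, placeholder
    if os.path.exists(state_path):
        c_quizzes = torch.load(state_path, map_location="cpu")["train_c_quizzes"]
        keep = torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]
        grids.save_quizzes_as_image(
            "/tmp", "c_quizzes.png", c_quizzes[keep], nrow=nrow, margin=10
        )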
diff --git a/main.py b/main.py
index 00722d6..4fee0f2 100755
--- a/main.py
+++ b/main.py
@@ -47,7 +47,7 @@ parser.add_argument("--eval_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=50000)
 
-parser.add_argument("--nb_test_samples", type=int, default=2500)
+parser.add_argument("--nb_test_samples", type=int, default=10000)
 
 parser.add_argument("--nb_c_quizzes", type=int, default=5000)
 
@@ -605,7 +605,10 @@ def one_complete_epoch(
     # Compute the test accuracy
 
     quizzes = generate_quiz_set(
-        args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
+        args.nb_test_samples,
+        c_quizzes=None,
+        c_quiz_multiplier=args.c_quiz_multiplier
+        # args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
     )
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
@@ -658,8 +661,6 @@ def evaluate_quizzes(quizzes, models, with_hints, local_device):
         nb_correct += (nb_mistakes == 0).long()
         nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
 
-    # print("\n\n", nb_correct, nb_wrong)
-
     return nb_correct, nb_wrong
 
 
@@ -796,6 +797,7 @@ def save_models(models, suffix=""):
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
                 "test_accuracy": model.test_accuracy,
+                "nb_epochs": model.nb_epochs,
             },
             os.path.join(args.result_dir, filename),
         )
@@ -830,6 +832,35 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     log_string(f"wrote {filename}")
 
 
+######################################################################
+
+
+def new_model(i):
+    if args.model_type == "standard":
+        model_constructor = attae.AttentionAE
+    elif args.model_type == "functional":
+        model_constructor = attae.FunctionalAttentionAE
+    else:
+        raise ValueError(f"Unknown model type {args.model_type}")
+
+    model = model_constructor(
+        vocabulary_size=vocabulary_size * 2,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+    )
+
+    model.id = i
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
+
+    return model
+
+
 ######################################################################
 
 problem = grids.Grids(
@@ -853,31 +884,10 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 models = []
 
-if args.model_type == "standard":
-    model_constructor = attae.AttentionAE
-elif args.model_type == "functional":
-    model_constructor = attae.FunctionalAttentionAE
-else:
-    raise ValueError(f"Unknown model type {args.model_type}")
-
-
 for i in range(args.nb_models):
-    model = model_constructor(
-        vocabulary_size=vocabulary_size * 2,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        dropout=args.dropout,
-    )
-
+    model = new_model(i)
     # model = torch.compile(model)
 
-    model.id = i
-    model.test_accuracy = 0.0
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
     models.append(model)
 
 ######################################################################
@@ -896,6 +906,8 @@ if args.resume:
         model.load_state_dict(d["state_dict"])
         model.optimizer.load_state_dict(d["optimizer_state_dict"])
         model.test_accuracy = d["test_accuracy"]
+        model.nb_epochs = d["nb_epochs"]
+
         log_string(f"successfully loaded {filename}")
 
     filename = "state.pth"
@@ -972,6 +984,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         # The c quizzes used to estimate the test accuracy have to be
         # solvable without hints
+
         nb_correct, _ = evaluate_quizzes(
             quizzes=train_c_quizzes,
             models=models,
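Taken together, the main.py changes make the per-model epoch count part of the checkpoint: new_model() initializes model.nb_epochs = 0, save_models() writes it, and the --resume path restores it. A minimal round-trip sketch using the same dict keys as save_models above (the path and model object are placeholders):

    # Sketch: checkpoint round-trip matching the keys written by
    # save_models: state_dict, optimizer_state_dict, test_accuracy, nb_epochs.
    import torch

    def save_one(model, path):
        torch.save(
            {
                "state_dict": model.state_dict(),
                "optimizer_state_dict": model.optimizer.state_dict(),
                "test_accuracy": model.test_accuracy,
                "nb_epochs": model.nb_epochs,
            },
            path,
        )

    def load_one(model, path):
        d = torch.load(path, map_location="cpu")
        model.load_state_dict(d["state_dict"])
        model.optimizer.load_state_dict(d["optimizer_state_dict"])
        model.test_accuracy = d["test_accuracy"]
        model.nb_epochs = d["nb_epochs"]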