Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 03:12:48 +0000 (05:12 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 21 Sep 2024 03:12:48 +0000 (05:12 +0200)
grids.py
main.py

index e5890ca..78d9297 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -287,7 +287,8 @@ class Grids(problem.Problem):
     ######################################################################
 
     def vocabulary_size(self):
-        warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        # return self.nb_colors+4
         return self.nb_colors
 
     def grid2img(self, x, scale=15, grids=True):
diff --git a/main.py b/main.py
index 961ae81..21666d1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -47,7 +47,7 @@ parser.add_argument("--eval_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=50000)
 
-parser.add_argument("--nb_test_samples", type=int, default=1000)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
 
 parser.add_argument("--nb_c_quizzes", type=int, default=5000)
 
@@ -252,6 +252,25 @@ assert args.nb_test_samples % args.batch_size == 0
 ######################################################################
 
 
+def optimizer_to(optim, device):
+    """Move the optimizer optim to the device"""
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
 def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
     if c_quizzes is None:
         quizzes = problem.generate_w_quizzes(nb_samples)
@@ -290,25 +309,6 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 ######################################################################
 
 
-def optimizer_to(optim, device):
-    """Move the optimizer optim to the device"""
-    for param in optim.state.values():
-        # Not sure there are any global tensors in the state dict
-        if isinstance(param, torch.Tensor):
-            param.data = param.data.to(device)
-            if param._grad is not None:
-                param._grad.data = param._grad.data.to(device)
-        elif isinstance(param, dict):
-            for subparam in param.values():
-                if isinstance(subparam, torch.Tensor):
-                    subparam.data = subparam.data.to(device)
-                    if subparam._grad is not None:
-                        subparam._grad.data = subparam._grad.data.to(device)
-
-
-######################################################################
-
-
 def add_hints_imt(imt_set):
     """Set every component of the mask to zero with probability
     args.proba_hint, and for each component set to zero, copy the
@@ -589,10 +589,12 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de
 ######################################################################
 
 
-def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
-    one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device)
+def one_complete_epoch(
+    model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
+    one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
 
-    one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device)
+    one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
 
     # Compute the test accuracy
 
@@ -627,7 +629,7 @@ def max_nb_mistakes_on_one_grid(quizzes, prediction):
     )
 
 
-def evaluate_quizzes(quizzes, models, local_device):
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
     nb_correct, nb_wrong = 0, 0
 
     for model in models:
@@ -636,20 +638,16 @@ def evaluate_quizzes(quizzes, models, local_device):
             model=model,
             input=quizzes,
             with_noise=False,
-            with_hints=True,
+            with_hints=with_hints,
             local_device=local_device,
         )
         nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
         nb_correct += (nb_mistakes == 0).long()
         nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
 
-    to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
-        nb_wrong >= args.nb_have_to_be_wrong
-    )
-
     # print("\n\n", nb_correct, nb_wrong)
 
-    return to_keep, nb_correct, nb_wrong
+    return nb_correct, nb_wrong
 
 
 ######################################################################
@@ -686,12 +684,17 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             # Select the ones that are solved properly by some models and
             # not understood by others
 
-            to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+            nb_correct, nb_wrong = evaluate_quizzes(
                 quizzes=c_quizzes,
                 models=models,
+                with_hints=True,
                 local_device=local_device,
             )
 
+            to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+                nb_wrong >= args.nb_have_to_be_wrong
+            )
+
             nb_validated += to_keep.long().sum().item()
             record.append(c_quizzes[to_keep])
 
@@ -743,8 +746,8 @@ def multithread_execution(fun, arguments):
 
     for args in arguments:
         # To get a different sequence between threads
-        log_string(f"dummy_rand {torch.rand(1)}")
-        torch.rand(1)
+        log_string(f"dummy_rand {torch.rand(1)}")
+        torch.rand(1)
         t = threading.Thread(target=threadable_fun, daemon=True, args=args)
         threads.append(t)
         t.start()
@@ -787,9 +790,10 @@ def save_models(models, suffix=""):
 def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     c_quizzes = c_quizzes.to(local_device)
 
-    to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+    nb_correct, nb_wrong = evaluate_quizzes(
         quizzes=c_quizzes,
         models=models,
+        with_hints=False,
         local_device=local_device,
     )
 
@@ -873,10 +877,6 @@ if args.resume:
         model.load_state_dict(d["state_dict"])
         model.optimizer.load_state_dict(d["optimizer_state_dict"])
         model.test_accuracy = d["test_accuracy"]
-        # model.gen_test_accuracy = d["gen_test_accuracy"]
-        # model.gen_state_dict = d["gen_state_dict"]
-        # model.train_c_quiz_bags = d["train_c_quiz_bags"]
-        # model.test_c_quiz_bags = d["test_c_quiz_bags"]
         log_string(f"successfully loaded {filename}")
 
     filename = "state.pth"
@@ -889,7 +889,8 @@ if args.resume:
     log_string(f"successfully loaded {filename}")
 
     current_epoch = state["current_epoch"]
-    c_quizzes = state["c_quizzes"]
+    train_c_quizzes = state["train_c_quizzes"]
+    test_c_quizzes = state["test_c_quizzes"]
 
 ######################################################################
 
@@ -899,7 +900,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-main_c_quizzes = None
+train_c_quizzes, test_c_quizzes = None, None
 
 ######################################################################
 
@@ -908,7 +909,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     state = {
         "current_epoch": n_epoch,
-        "main_c_quizzes": main_c_quizzes,
+        "train_c_quizzes": train_c_quizzes,
+        "test_c_quizzes": test_c_quizzes,
     }
 
     filename = "state.pth"
@@ -925,7 +927,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
 
     if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
-        if main_c_quizzes is None:
+        if train_c_quizzes is None:
             save_models(models, "naive")
 
         nb_gpus = len(gpus)
@@ -942,20 +944,29 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
 
-        main_c_quizzes = (
+        train_c_quizzes = (
             new_c_quizzes
-            if main_c_quizzes is None
-            else torch.cat([main_c_quizzes, new_c_quizzes])
+            if train_c_quizzes is None
+            else torch.cat([train_c_quizzes, new_c_quizzes])
         )
-        main_c_quizzes = main_c_quizzes[-args.nb_train_samples :]
+        train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
+
+        nb_correct, _ = evaluate_quizzes(
+            quizzes=train_c_quizzes,
+            models=models,
+            with_hints=False,
+            local_device=local_device,
+        )
+
+        test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
 
         for model in models:
             model.test_accuracy = 0
 
-    if main_c_quizzes is None:
+    if train_c_quizzes is None:
         log_string("no_c_quiz")
     else:
-        log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}")
+        log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
 
     # --------------------------------------------------------------------
 
@@ -969,7 +980,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     multithread_execution(
         one_complete_epoch,
         [
-            (model, n_epoch, main_c_quizzes, gpu)
+            (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
             for model, gpu in zip(weakest_models, gpus)
         ],
     )