Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 21:47:51 +0000 (23:47 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 20 Jul 2024 21:47:51 +0000 (23:47 +0200)
grids.py
main.py
quiz_machine.py

index 4db12db..bbb18d2 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -135,10 +135,10 @@ class Grids(problem.Problem):
 
         if shape == "fwd_3_bck_123":
             forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+            backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long()
         elif shape == "fwd_012_bck_0":
             forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
-            backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
+            backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long()
         elif shape == "fwd_3_bck_3":
             forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
             backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
@@ -1277,7 +1277,6 @@ class Grids(problem.Problem):
         S = self.height * self.width
         Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
         f_Bs = answers[:, 1:]
-        print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}")
         return (Bs == f_Bs).long().min(dim=-1).values > 0
 
     def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
@@ -1371,8 +1370,8 @@ if __name__ == "__main__":
     nb, nrow = 8, 2
     # nb, nrow = 8, 2
 
-    for t in grids.all_tasks:
-        # for t in [grids.task_compute]:
+    for t in grids.all_tasks:
+    for t in [grids.task_convex]:
         print(t.__name__)
         prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
diff --git a/main.py b/main.py
index 562a95d..c9c30c3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -558,6 +558,8 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
+current_epoch = 0
+
 if args.resume:
     try:
         for model in models:
@@ -580,6 +582,15 @@ if args.resume:
             log_string(f"cannot find {filename}")
             pass
 
+        try:
+            filename = "state.pth"
+            state = torch.load(os.path.join(args.result_dir, filename))
+            log_string(f"successfully loaded {filename}")
+            current_epoch = state["current_epoch"]
+        except FileNotFoundError:
+            log_string(f"cannot find {filename}")
+            pass
+
     except:
         log_string(f"error when loading {filename}.")
         exit(1)
@@ -616,7 +627,7 @@ if args.dirty_debug:
 
 ######################################################################
 
-for n_epoch in range(args.nb_epochs):
+for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
@@ -675,6 +686,11 @@ for n_epoch in range(args.nb_epochs):
         )
         log_string(f"wrote {filename}")
 
+    state = {"current_epoch": n_epoch}
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
+
     # Renew the training samples
 
     for model in weakest_models:
index 91eb3ac..c006ea4 100755 (executable)
@@ -447,9 +447,7 @@ class QuizMachine:
 
     ###############################################################
 
-    def solution_nb_correct(
-        self, models_for_validation, c_quizzes, bidirectional_validation=True
-    ):
+    def solution_nb_correct(self, models_for_validation, c_quizzes):
         seq_logproba = torch.zeros(
             c_quizzes.size(0),
             max([m.id for m in models_for_validation]) + 1,
@@ -457,6 +455,12 @@ class QuizMachine:
         )
 
         nb_correct = 0
+        correct_models = torch.empty(
+            c_quizzes.size(0),
+            max([m.id for m in models_for_validation]) + 1,
+            device=self.device,
+            dtype=torch.int64,
+        )
 
         seq_logproba[...] = 0.0
 
@@ -478,7 +482,9 @@ class QuizMachine:
                 device=self.device,
             )
 
-            correct = (c_quizzes == result).long().min(dim=-1).values
+            correct_models[:, model.id] = (
+                (c_quizzes == result).long().min(dim=-1).values
+            )
 
             # -------------------------------
 
@@ -501,13 +507,17 @@ class QuizMachine:
                 device=self.device,
             )
 
-            flipped_correct = (c_quizzes == result).long().min(dim=-1).values
+            correct_models[:, model.id] *= (
+                (c_quizzes == result).long().min(dim=-1).values
+            )
 
             # -------------------------------
 
-            nb_correct += correct * flipped_correct
+        i = correct_models.sum(dim=1) == correct_models.size(1) - 1
+        c = (correct_models[i] == 0).long().sum(dim=0)
+        self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}")
 
-        return nb_correct.to("cpu")
+        return correct_models.sum(dim=1).to("cpu")
 
     ###############################################################