Update.
author François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 06:38:30 +0000 (08:38 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 20 Sep 2024 06:38:30 +0000 (08:38 +0200)
grids.py
main.py

diff --git a/grids.py b/grids.py
index 0613043..197eb5a 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -144,9 +144,9 @@ class Grids(problem.Problem):
     # background_gray=240
     # dots = False
 
-    # grid_gray = 200
+    # grid_gray = 192
     # thickness = 0
-    # background_gray = 240
+    # background_gray = 255
     # dots = True
 
     named_colors = [
@@ -287,7 +287,8 @@ class Grids(problem.Problem):
     ######################################################################
 
     def vocabulary_size(self):
-        return self.nb_colors
+        warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        return self.nb_colors + 4
 
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
@@ -313,13 +314,12 @@ class Grids(problem.Problem):
                 :,
                 :,
                 :,
-                scale // 2 - 2 : scale // 2 + 1,
+                scale // 2 - 1 : scale // 2 + 2,
                 :,
-                scale // 2 - 2 : scale // 2 + 1,
+                scale // 2 - 1 : scale // 2 + 2,
             ]
-            z[...] = (z == self.background_gray) * self.grid_gray + (
-                z != self.background_gray
-            ) * z
+            zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+            z[...] = zz * self.grid_gray + (zz == False) * z
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
@@ -367,7 +367,7 @@ class Grids(problem.Problem):
         comment_height=48,
         nrow=4,
         grids=True,
-        margin=8,
+        margin=12,
         delta=False,
     ):
         quizzes = quizzes.to("cpu")
@@ -446,10 +446,12 @@ class Grids(problem.Problem):
                     + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
 
-        img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
-        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
-        img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
-        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
+        separation = 6
+
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
 
         img_A = self.add_frame(img_A, white[None, :], thickness=2)
         img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
@@ -457,9 +459,13 @@ class Grids(problem.Problem):
         img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
         if delta:
-            img_delta_A = self.add_frame(img_delta_A, colors[:, 0], thickness=8)
+            img_delta_A = self.add_frame(
+                img_delta_A, colors[:, 0], thickness=separation
+            )
             img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
-            img_delta_B = self.add_frame(img_delta_B, colors[:, 0], thickness=8)
+            img_delta_B = self.add_frame(
+                img_delta_B, colors[:, 0], thickness=separation
+            )
             img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
             img = torch.cat(
                 [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
diff --git a/main.py b/main.py
index 52505de..06dfc5e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -272,9 +272,7 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
             c_quizzes = c_quizzes[i]
 
         w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
-        w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-            w_quizzes.size(0), -1
-        )
+
         quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
         nb_w_quizzes = w_quizzes.size(0)
         nb_c_quizzes = c_quizzes.size(0)
@@ -383,7 +381,9 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, with_perturbations=False, local_device=main_device):
+def predict_full(
+    model, input, with_noise=False, with_hints=False, local_device=main_device
+):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
@@ -393,8 +393,10 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic
     input = (1 - masks) * targets
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
-    if with_perturbations:
+    if with_hints:
         imt_set = add_hints_imt(imt_set)
+
+    if with_noise:
         imt_set = add_noise_imt(imt_set)
 
     result = ae_predict(model, imt_set, local_device=local_device, desc=None)
@@ -563,7 +565,13 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
     )
-    result = predict_full(model=model, input=quizzes, local_device=local_device)
+    result = predict_full(
+        model=model,
+        input=quizzes,
+        with_noise=True,
+        with_hints=True,
+        local_device=local_device,
+    )
     problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
@@ -630,27 +638,30 @@ def evaluate_quizzes(quizzes, models, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            with_perturbations=True,
+            with_noise=False,
+            with_hints=True,
             local_device=local_device,
         )
 
-        nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long()
+        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result)
+        nb_correct += (nb_mistakes == 0).long()
 
-        result = predict_full(
-            model=model,
-            input=quizzes,
-            with_perturbations=False,
-            local_device=local_device,
-        )
+        # result = predict_full(
+        # model=model,
+        # input=quizzes,
+        # with_noise=False,
+        # with_hints=False,
+        # local_device=local_device,
+        # )
 
-        nb_wrong += (
-            max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong
-        ).long()
+        nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
 
     to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
         nb_wrong >= args.nb_have_to_be_wrong
     )
 
+    # print("\n\n", nb_correct, nb_wrong)
+
     return to_keep, nb_correct, nb_wrong
 
 
@@ -659,7 +670,7 @@ def evaluate_quizzes(quizzes, models, local_device):
 
 def identity_quizzes(quizzes):
     quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
-    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
+    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
         quizzes[:, 2] == quizzes[:, 3]
     ).min(dim=1).values