Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 10:38:26 +0000 (12:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 10:38:26 +0000 (12:38 +0200)
attae.py
grids.py
main.py
quiz_machine.py

index a9bdeba..1e5e122 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -180,8 +180,8 @@ class FunctionalAttentionAE(AttentionAE):
             a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
             n = self.nb_work_tokens
             s = (q.size(2) - n) // 2
-            a[:, :, n + 0 * s : n + 1 * s, n + 0 * s : n + 1 * s] = float("-inf")
-            a[:, :, n + 1 * s : n + 2 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
             a = a.softmax(dim=3)
             y = torch.einsum("nhts,nhsd->nhtd", a, v)
             return y
index 23a3d12..4254b32 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -134,24 +134,28 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
-    # grid_gray=64
+    grid_gray = 64
+    thickness = 1
+    background_gray = 255
+
+    # grid_gray=240
     # thickness=1
-    # background_gray=255
+    # background_gray=240
 
-    grid_gray = 255
-    thickness = 0
-    background_gray = grid_gray
+    grid_gray = 255
+    thickness = 0
+    # background_gray = 240
 
     named_colors = [
         ("white", [background_gray, background_gray, background_gray]),
         # ("white", [224, 224, 224]),
         ("red", [255, 0, 0]),
-        ("green", [0, 192, 0]),
+        ("green", [0, 160, 0]),
         ("blue", [0, 0, 255]),
         ("yellow", [255, 224, 0]),
         ("cyan", [0, 255, 255]),
         ("violet", [224, 128, 255]),
-        ("lightgreen", [192, 255, 192]),
+        ("lightgreen", [160, 255, 160]),
         ("brown", [165, 42, 42]),
         ("lightblue", [192, 192, 255]),
         ("gray", [128, 128, 128]),
diff --git a/main.py b/main.py
index d903693..750d1b1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -54,7 +54,7 @@ parser.add_argument("--nb_train_samples", type=int, default=50000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
-parser.add_argument("--nb_c_quizzes", type=int, default=10000)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
 
 parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
@@ -64,7 +64,7 @@ parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
 
 parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
 
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10)
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
 
 # ----------------------------------
 
@@ -198,7 +198,8 @@ if args.seed >= 0:
 
 
 def log_string(s):
-    """print the given string prefixed with a time stamps, and log it into log_file is not None"""
+    """print the given string prefixed with a time stamps, and log it
+    into log_file is not None"""
 
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
@@ -301,10 +302,9 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
-# If we need to move an optimizer to a different device
-
 
 def optimizer_to(optim, device):
+    """Move the optimizer optim to the device"""
     for param in optim.state.values():
         # Not sure there are any global tensors in the state dict
         if isinstance(param, torch.Tensor):
@@ -322,11 +322,12 @@ def optimizer_to(optim, device):
 ######################################################################
 
 
-# Make args.nb_hints holes in the mask and copy the corresponding cell
-# values from the target to the input
-
-
 def add_hints_imt(imt_set):
+    """Set every component of the mask to zero with probability
+    args.proba_hint, and for each component set to zero, copy the
+    corresponding value from the target into the input
+
+    """
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     # h = torch.rand(masks.size(), device=masks.device) - masks
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
@@ -339,11 +340,9 @@ def add_hints_imt(imt_set):
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-# Make pixels from the available input (mask=0) noise with probability
-# args.proba_prompt_noise
-
-
 def add_noise_imt(imt_set):
+    """Replace every component of the input by a random value with
+    probability args.proba_prompt_noise."""
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     noise = quiz_machine.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
@@ -633,8 +632,8 @@ import attae
 models = []
 
 for i in range(args.nb_models):
-    model = attae.FunctionalAttentionAE(
-        # model = attae.AttentionAE(
+    model = attae.FunctionalAttentionAE(
+    model = attae.AttentionAE(
         vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -655,7 +654,7 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
+def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
     nb_correct, nb_wrong = 0, 0
 
     for model in models:
@@ -663,7 +662,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            with_perturbations=True,
+            with_perturbations=with_perturbations,
             local_device=local_device,
         )
         nb_mistakes = (result != quizzes).long().sum(dim=1)
@@ -680,6 +679,26 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
 ######################################################################
 
 
+def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device):
+    nb_removed = 0
+    for input in c_quizzes.split(args.eval_batch_size):
+        _, nb_correct, nb_wrong = evaluate_quizzes(
+            quizzes=input,
+            models=models,
+            with_perturbations=False,
+            local_device=local_device,
+        )
+
+        to_remove = nb_wrong > 0
+        nb_removed += to_remove.long().sum()
+
+        if nb_removed >= nb_to_remove:
+            break
+
+
+######################################################################
+
+
 def identity_quizzes(quizzes):
     quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
     return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
@@ -714,7 +733,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             to_keep, nb_correct, nb_wrong = evaluate_quizzes(
                 quizzes=c_quizzes,
                 models=models,
-                fraction_with_hints=1.0,
+                with_perturbations=True,
                 local_device=local_device,
             )
 
@@ -760,7 +779,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     to_keep, nb_correct, nb_wrong = evaluate_quizzes(
         quizzes=c_quizzes,
         models=models,
-        fraction_with_hints=0,
+        with_perturbations=False,
         local_device=local_device,
     )
 
index e2f6d3b..72f1d16 100755 (executable)
@@ -206,6 +206,8 @@ class QuizMachine:
             quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
                 quizzes.size(0), -1
             )
+            nb_w_quizzes = quizzes.size(0)
+            nb_c_quizzes = 0
         else:
             if c_quiz_multiplier > 1:
                 n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
@@ -229,10 +231,14 @@ class QuizMachine:
                 w_quizzes.size(0), -1
             )
             quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+            nb_w_quizzes = w_quizzes.size(0)
+            nb_c_quizzes = c_quizzes.size(0)
 
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes = quizzes[i].contiguous()
 
+        logger(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
         return quizzes
 
     ######################################################################