Update.
author François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:53:54 +0000 (07:53 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:53:54 +0000 (07:53 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 00a6cd1..cd78959 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -969,6 +969,7 @@ def test_ae(local_device=main_device):
             targets = input
 
             input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
+
             result = (1 - mask_generate) * input + mask_generate * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
@@ -985,8 +986,8 @@ def test_ae(local_device=main_device):
                 result[not_converged] = update[not_converged]
                 not_converged = (pred_result != result).max(dim=1).values
                 nb_it += 1
-                print("DEBUG", nb_it, i.long().sum().item())
-                if not i.any() or nb_it > 100:
+                print("DEBUG", nb_it, not_converged.long().sum().item())
+                if not not_converged.any() or nb_it > 100:
                     break
 
             correct = (result == targets).min(dim=1).values.long()
diff --git a/quiz_machine.py b/quiz_machine.py
index 0f13964..af24c92 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -81,7 +81,7 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
-        # struct, quad_generate, quad_noise, quad_loss
+        # quad_order, quad_generate, quad_noise, quad_loss
         self.train_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
@@ -171,32 +171,34 @@ class QuizMachine:
         else:
             quizzes = self.problem.generate_w_quizzes(nb_samples)
 
+        # shuffle
+
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes = quizzes[i]
 
-        self.randomize_configuations_inplace(
-            quizzes, quad_orders=[s for s, _, _, _ in data_structures]
-        )
+        # Re-order and inject noise
 
         quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
-
-        for quad_order, quad_generate, quad_noise, quad_loss in data_structures:
-            i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order)
-            if i.any():
-                if self.prompt_noise > 0.0:
-                    quizzes[i] = self.problem.inject_noise(
-                        quizzes[i],
-                        self.prompt_noise,
-                        quad_order=quad_order,
-                        quad_noise=quad_noise,
-                    )
-                quiz_mask_generate[i] = self.make_quiz_mask(
-                    quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
-                )
-                quiz_mask_loss[i] = self.make_quiz_mask(
-                    quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
+        order_ids = torch.randint(len(data_structures), (quizzes.size(0),))
+
+        for j, s in enumerate(data_structures):
+            quad_order, quad_generate, quad_noise, quad_loss = s
+            i = order_ids == j
+            quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
+            if self.prompt_noise > 0.0:
+                quizzes[i] = self.problem.inject_noise(
+                    quizzes[i],
+                    self.prompt_noise,
+                    quad_order=quad_order,
+                    quad_noise=quad_noise,
                 )
+            quiz_mask_generate[i] = self.make_quiz_mask(
+                quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
+            )
+            quiz_mask_loss[i] = self.make_quiz_mask(
+                quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
+            )
 
         return quizzes, quiz_mask_generate, quiz_mask_loss