Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 04:10:02 +0000 (06:10 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 25 Jul 2024 04:10:02 +0000 (06:10 +0200)
grids.py
main.py
quiz_machine.py

index f6129e9..93b027a 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -137,6 +137,7 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
+    # What a mess
     def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         if torch.is_tensor(quizzes):
             return self.reconfigure([quizzes], struct=struct)[0]
@@ -165,11 +166,11 @@ class Grids(problem.Problem):
 
         return result
 
-    def non_trivial(self, quizzes):
+    def trivial(self, quizzes):
         S = self.height * self.width
         assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
         a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
-        return (a[:, 0] == a[:, 1]).min(dim=1).values & (a[:, 2] == a[:, 3]).min(
+        return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
             dim=1
         ).values
 
diff --git a/main.py b/main.py
index fa33b4e..257f40f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -454,7 +454,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # We discard the trivial ones, according to a criterion
         # specific to the world quizzes (e.g. B=f(B))
 
-        c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)]
+        c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False]
 
         # We go through nb_rounds rounds and keep only quizzes on
         # which
index 4615e3a..2ca584e 100755 (executable)
@@ -270,14 +270,12 @@ class QuizMachine:
 
     ######################################################################
 
-    def randomize_configuations_inplace(self, quizzes, configurations):
-        r = torch.randint(
-            len(configurations), (quizzes.size(0),), device=quizzes.device
-        )
+    def randomize_configuations_inplace(self, quizzes, structs):
+        r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
 
-        for c in range(len(configurations)):
+        for c in range(len(structs)):
             quizzes[r == c] = self.problem.reconfigure(
-                quizzes[r == c], struct=configurations[c]
+                quizzes[r == c], struct=structs[c]
             )
 
     def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
@@ -285,11 +283,11 @@ class QuizMachine:
         model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, configurations=self.train_struct
+            model.train_w_quizzes, structs=self.train_struct
         )
 
         self.randomize_configuations_inplace(
-            model.test_w_quizzes, configurations=self.train_struct
+            model.test_w_quizzes, structs=self.train_struct
         )
 
     ######################################################################
@@ -322,7 +320,7 @@ class QuizMachine:
             )
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, configurations=self.train_struct
+            model.train_w_quizzes, structs=self.train_struct
         )
 
     ######################################################################