Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 13:59:18 +0000 (15:59 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 23 Jul 2024 13:59:18 +0000 (15:59 +0200)
grids.py
quiz_machine.py

index 406c0b7..a158c27 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -1418,6 +1418,12 @@ class Grids(problem.Problem):
                 if accept_full or (d * (X == 0)).max() == self.height * self.width:
                     break
 
+    def task_addition(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[: 3 + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+
     # end_tasks
 
     ######################################################################
@@ -1529,7 +1535,7 @@ if __name__ == "__main__":
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
index 4b07de3..2d38fab 100755 (executable)
@@ -370,7 +370,7 @@ class QuizMachine:
     ######################################################################
 
     def renew_train_w_quizzes(self, model, p2a_only=False):
-        if for_train and hasattr(model, "hard_w_quizzes"):
+        if hasattr(model, "hard_w_quizzes"):
             self.logger(
                 f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
             )