From cf33212d343aa636233aff827041aaa2cc0c205a Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= <francois@fleuret.org>
Date: Tue, 23 Jul 2024 15:59:18 +0200
Subject: [PATCH] Update.

---
 grids.py        | 8 +++++++-
 quiz_machine.py | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/grids.py b/grids.py
index 406c0b7..a158c27 100755
--- a/grids.py
+++ b/grids.py
@@ -1418,6 +1418,12 @@ class Grids(problem.Problem):
                 if accept_full or (d * (X == 0)).max() == self.height * self.width:
                     break
 
+    def task_addition(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[: 3 + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+
     # end_tasks
 
     ######################################################################
@@ -1529,7 +1535,7 @@ if __name__ == "__main__":
             "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
         )
 
-    exit(0)
+    # exit(0)
 
     nb = 1000
 
diff --git a/quiz_machine.py b/quiz_machine.py
index 4b07de3..2d38fab 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -370,7 +370,7 @@ class QuizMachine:
     ######################################################################
 
     def renew_train_w_quizzes(self, model, p2a_only=False):
-        if for_train and hasattr(model, "hard_w_quizzes"):
+        if hasattr(model, "hard_w_quizzes"):
             self.logger(
                 f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
             )
-- 
2.39.5