From cf33212d343aa636233aff827041aaa2cc0c205a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 15:59:18 +0200 Subject: [PATCH] Update. --- grids.py | 8 +++++++- quiz_machine.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/grids.py b/grids.py index 406c0b7..a158c27 100755 --- a/grids.py +++ b/grids.py @@ -1418,6 +1418,12 @@ class Grids(problem.Problem): if accept_full or (d * (X == 0)).max() == self.height * self.width: break + def task_addition(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[: 3 + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + # end_tasks ###################################################################### @@ -1529,7 +1535,7 @@ if __name__ == "__main__": "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - exit(0) + # exit(0) nb = 1000 diff --git a/quiz_machine.py b/quiz_machine.py index 4b07de3..2d38fab 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -370,7 +370,7 @@ class QuizMachine: ###################################################################### def renew_train_w_quizzes(self, model, p2a_only=False): - if for_train and hasattr(model, "hard_w_quizzes"): + if hasattr(model, "hard_w_quizzes"): self.logger( f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" ) -- 2.20.1