From 0c93da3f4b3761bc43593836492fda3d72444c1b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 23 Jul 2024 23:53:34 +0200 Subject: [PATCH] Update. --- grids.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/grids.py b/grids.py index a158c27..e64cb33 100755 --- a/grids.py +++ b/grids.py @@ -1419,10 +1419,20 @@ class Grids(problem.Problem): break def task_addition(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[: 3 + 1] + 1 + c = torch.randperm(len(self.colors) - 1)[:4] + 1 for X, f_X in [(A, f_A), (B, f_B)]: N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() - N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + S = N1 + N2 + for j in range(self.width): + r1 = (N1 // (2**j)) % 2 + X[0, -j - 1] = c[r1] + f_X[0, -j - 1] = c[r1] + r2 = (N2 // (2**j)) % 2 + X[1, -j - 1] = c[r2] + f_X[1, -j - 1] = c[r2] + rs = (S // (2**j)) % 2 + f_X[2, -j - 1] = c[2 + rs] # end_tasks @@ -1526,7 +1536,7 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_fill]: + for t in [grids.task_reconfigure]: # for t in [grids.task_symbols]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) @@ -1535,7 +1545,7 @@ if __name__ == "__main__": "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - # exit(0) + exit(0) nb = 1000 -- 2.39.5