From 9cbc374a4abb5bcb488391ffe9eec80750fdc67d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 21 Jul 2024 08:05:53 +0200 Subject: [PATCH] Update. --- grids.py | 77 ++++++++++++++++++++++++++++++++++++++++--------- main.py | 1 + quiz_machine.py | 8 ++--- 3 files changed, 68 insertions(+), 18 deletions(-) diff --git a/grids.py b/grids.py index bbb18d2..3c00abe 100755 --- a/grids.py +++ b/grids.py @@ -219,9 +219,9 @@ class Grids(problem.Problem): self.task_trajectory, self.task_bounce, self.task_scale, - # self.task_symbols, + self.task_symbols, self.task_isometry, - self.task_islands, + # self.task_islands, ] if tasks is None: @@ -430,7 +430,8 @@ class Grids(problem.Problem): while True: i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values - + i[:, 1] += 1 + j[:, 1] += 1 big_enough = ( (i[:, 1] >= i[:, 0] + min_height) & (j[:, 1] >= j[:, 0] + min_height) @@ -903,27 +904,28 @@ class Grids(problem.Problem): if d.min() > delta: break - for k in range(1, nb_rec): - X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] - ai, aj = i.float().mean(), j.float().mean() q = torch.randint(3, (1,)).item() + 1 - X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] - assert i[q] != ai and j[q] != aj + for Z in [X, f_X]: + for k in range(1, nb_rec): + Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] + Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] + Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] + Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] + Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + + # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q] + X[ i[0] + delta // 2 + (i[q] - ai).sign().long(), j[0] + delta // 2 + (j[q] - aj).sign().long(), ] = c[nb_rec] - f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] - # @torch.compile def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 @@ -1271,6 +1273,52 @@ class Grids(problem.Problem): f_X[i, s : s + w1] = c1 f_X[i, s + w1 : s + w1 + w2] = c2 + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_proximity(self, A, f_A, B, f_B): + def rec_dist(a, b): + ai1, aj1, ai2, aj2 = a + bi1, bj1, bi2, bj2 = b + v = max(ai1 - bi2, bi1 - ai2) + h = max(aj1 - bj2, bj1 - aj2) + return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0)) + + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + d = [rec_dist(r[0], r[k]) for k in range(nb_rec)] + if min(d[1:]) == 0: + break + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + if d[n] == 0: + f_X[i1:i2, j1:j2] = c[0] + else: + f_X[i1:i2, j1:j2] = c[n] + + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_corners(self, A, f_A, B, f_B): + polarity = torch.randint(2, (1,)).item() + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(nb_rec, prevent_overlap=True) + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + if polarity == 0: + X[i1, j1] = c[n] + X[i2 - 1, j2 - 1] = c[n] + else: + X[i1, j2 - 1] = c[n] + X[i2 - 1, j1] = c[n] + f_X[i1:i2, j1:j2] = c[n] + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): @@ -1371,7 +1419,8 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_convex]: + # for t in [grids.task_proximity, grids.task_corners]: + for t in [grids.task_symbols]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) # prompts[...] = torch.randint(grids.nb_token_values(), prompts.size()) diff --git a/main.py b/main.py index 5ce9731..b7d0431 100755 --- a/main.py +++ b/main.py @@ -343,6 +343,7 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de model.main_test_accuracy = quiz_machine.produce_results( n_epoch=n_epoch, model=model, + input=full_input[:2000], result_dir=args.result_dir, deterministic_synthesis=deterministic_synthesis, ) diff --git a/quiz_machine.py b/quiz_machine.py index c006ea4..cbaa7cd 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -290,7 +290,9 @@ class QuizMachine: ###################################################################### - def produce_results(self, n_epoch, model, result_dir, deterministic_synthesis): + def produce_results( + self, n_epoch, model, input, result_dir, deterministic_synthesis + ): def compute_accuracy(input, log_prefix=None): input = input.to(self.device) ar_mask = self.problem.make_ar_mask(input, shape="fwd_3_bck_123") @@ -334,9 +336,7 @@ class QuizMachine: return result, correct - test_result, test_correct = compute_accuracy( - model.test_w_quizzes[:2000], log_prefix="test" - ) + test_result, test_correct = compute_accuracy(input, log_prefix="test") n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward -- 2.39.5