From e0ab20005e2578edff27d4246c6904cf1047ed22 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 11:55:58 +0300 Subject: [PATCH] Update. --- grids.py | 60 +++++++++++++++++++++++++------------------------ main.py | 1 + quiz_machine.py | 7 +++--- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/grids.py b/grids.py index 247c146..03da16d 100755 --- a/grids.py +++ b/grids.py @@ -581,34 +581,36 @@ class Grids(problem.Problem): def task_islands(self, A, f_A, B, f_B): for X, f_X in [(A, f_A), (B, f_B)]: + nb_on_border = 0 + for _ in range(10): + for k in torch.randperm(self.height * self.width): + i, j = k % self.height, k // self.height + border = ( + i == 0 or i == self.height - 1 or j == 0 or j == self.width - 1 + ) + no, nq, nq_diag = self.contact(X, i, j, 1) + + if ( + (nq > 0 and not border) + or (nq == 0 and border and nb_on_border < 4) + ) and nq_diag == 0: + X[i, j] = 1 + if border: + nb_on_border += 1 + while True: - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) - if ( - i == 0 - or i == self.height - 1 - or j == 0 - or j == self.width - 1 - or X[i, j] == 1 - ): - break - while True: - di, dj = torch.randint(3, (2,)) - 1 - if abs(di) + abs(dj) > 0: - break - X[i, j] = 1 - while True: - i, j = i + di, j + dj - if i < 0 or i >= self.height or j < 0 or j >= self.width: - break - b = ( - i == 0 - or i == self.height - 1 - or j == 0 - or j == self.width - 1 - or X[i, j] == 1 - ) - X[i, j] = 1 - if b: + nb_fixes = 0 + for i in range(1, self.height - 1): + for j in range(1, self.width - 1): + if ( + X[i, j] == 1 + and X[i - 1, j] + X[i + 1, j] + X[i, j - 1] + X[i, j + 1] + == 1 + ): + X[i, j] = 0 + nb_fixes += 1 + + if nb_fixes == 0: break ###################################################################### @@ -681,8 +683,8 @@ if __name__ == "__main__": grids = Grids() - for t in grids.all_tasks(): - # for t in [grids.task_islands]: + # for t in grids.all_tasks(): + for t in [grids.task_islands]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) diff --git a/main.py b/main.py index bfbc0e4..6b46fa0 100755 --- a/main.py +++ b/main.py @@ -418,6 +418,7 @@ def create_c_quizzes( ) file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + with open(file_name, "w") as logp_file: while ( valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0) diff --git a/quiz_machine.py b/quiz_machine.py index de1e8d1..cb187be 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -50,7 +50,8 @@ def one_batch_masked_inplace_autoregression( t_next = dist.sample() all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next].sum(dim=-1) + + seq_logproba += logits[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -420,11 +421,11 @@ class QuizMachine: nb_correct = 0 + seq_logproba[...] = 0.0 + for model in models_for_validation: result = c_quizzes.clone() - seq_logproba[...] = 0.0 - ar_mask = self.make_ar_mask(result) masked_inplace_autoregression( -- 2.39.5