From 6917d3d52a4b473d31121a471ab98fa114bdb1a6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 21 Jun 2024 21:27:44 +0200 Subject: [PATCH] Update. --- main.py | 4 +++- tasks.py | 25 +++++++++++++++---------- world.py | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 22edf7b..4a1207d 100755 --- a/main.py +++ b/main.py @@ -845,6 +845,7 @@ def run_tests(model, task, deterministic_synthesis): def create_quizzes( + model, other_models, task, nb_for_train=1000, @@ -861,7 +862,8 @@ def create_quizzes( result_dir=args.result_dir, logger=log_string, nb=4 * (nb_for_train + nb_for_test), - models=other_models, + model=model, + other_models=other_models, nb_runs=nb_runs, ) diff --git a/tasks.py b/tasks.py index 622cd56..0345bd0 100755 --- a/tasks.py +++ b/tasks.py @@ -2244,7 +2244,9 @@ class World(Task): input[:nb_kept] = input[-nb_kept:].clone() input[nb_kept:] = new_quizzes - def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs): + def create_new_quizzes( + self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs + ): new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) @@ -2274,15 +2276,18 @@ class World(Task): .expand_as(result) ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis=False, - progress_bar_desc=None, - device=self.device, - ) + dispatch = torch.randint(len(other_models), (result.size(0),)) + + for n, m in enumerate(other_models): + masked_inplace_autoregression( + m, + self.batch_size, + result[dispatch == n], + ar_mask[dispatch == n], + deterministic_synthesis=False, + progress_bar_desc=None, + device=self.device, + ) nb_correct = ( (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1) diff --git a/world.py b/world.py index 97c7b1d..43126d5 100755 --- a/world.py +++ b/world.py @@ -34,7 +34,7 @@ def generate( nb, height, width, - max_nb_obj=colors.size(0) - 2, + max_nb_obj=2, nb_iterations=2, ): f_start = torch.zeros(nb, height, width, dtype=torch.int64) -- 2.39.5