From 32ca97e128baa0392684fd5a84632651178b3d89 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 23 Jun 2024 15:19:55 +0200 Subject: [PATCH] Update. --- main.py | 4 +++- tasks.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 3b29d01..683c07d 100755 --- a/main.py +++ b/main.py @@ -419,6 +419,8 @@ for n_epoch in range(args.nb_epochs): # improve it one_epoch(model, task) + task.renew_samples(args.nb_train_samples // args.nb_gpts) + log_string( f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}" ) @@ -426,7 +428,7 @@ for n_epoch in range(args.nb_epochs): # test it run_tests(model, task, deterministic_synthesis=False) - if model.main_test_accuracy >= accuracy_to_make_quizzes: + if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes: other_models = models.copy() other_models.remove(model) diff --git a/tasks.py b/tasks.py index 27173e1..2c88333 100755 --- a/tasks.py +++ b/tasks.py @@ -223,6 +223,14 @@ class World(Task): return main_test_accuracy + def renew_samples(self, nb, for_train=True): + input = self.train_input if for_train else self.test_input + nb = min(nb, input.size(0)) + input[:-nb] = input[nb:].clone() + input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( + self.device + ) + def store_new_quizzes(self, new_quizzes, for_train=True): if for_train: self.train_quizzes.append(new_quizzes) -- 2.20.1