From: François Fleuret Date: Sun, 23 Jun 2024 13:19:55 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=32ca97e128baa0392684fd5a84632651178b3d89;p=culture.git Update. --- diff --git a/main.py b/main.py index 3b29d01..683c07d 100755 --- a/main.py +++ b/main.py @@ -419,6 +419,8 @@ for n_epoch in range(args.nb_epochs): # improve it one_epoch(model, task) + task.renew_samples(args.nb_train_samples // args.nb_gpts) + log_string( f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}" ) @@ -426,7 +428,7 @@ for n_epoch in range(args.nb_epochs): # test it run_tests(model, task, deterministic_synthesis=False) - if model.main_test_accuracy >= accuracy_to_make_quizzes: + if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes: other_models = models.copy() other_models.remove(model) diff --git a/tasks.py b/tasks.py index 27173e1..2c88333 100755 --- a/tasks.py +++ b/tasks.py @@ -223,6 +223,14 @@ class World(Task): return main_test_accuracy + def renew_samples(self, nb, for_train=True): + input = self.train_input if for_train else self.test_input + nb = min(nb, input.size(0)) + input[:-nb] = input[nb:].clone() + input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( + self.device + ) + def store_new_quizzes(self, new_quizzes, for_train=True): if for_train: self.train_quizzes.append(new_quizzes)