Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 23 Jun 2024 13:19:55 +0000 (15:19 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 23 Jun 2024 13:19:55 +0000 (15:19 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 3b29d01..683c07d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -419,6 +419,8 @@ for n_epoch in range(args.nb_epochs):
     # improve it
     one_epoch(model, task)
 
+    task.renew_samples(args.nb_train_samples // args.nb_gpts)
+
     log_string(
         f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
     )
@@ -426,7 +428,7 @@ for n_epoch in range(args.nb_epochs):
     # test it
     run_tests(model, task, deterministic_synthesis=False)
 
-    if model.main_test_accuracy >= accuracy_to_make_quizzes:
+    if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
         other_models = models.copy()
         other_models.remove(model)
 
index 27173e1..2c88333 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -223,6 +223,14 @@ class World(Task):
 
         return main_test_accuracy
 
+    def renew_samples(self, nb, for_train=True):
+        input = self.train_input if for_train else self.test_input
+        nb = min(nb, input.size(0))
+        input[:-nb] = input[nb:].clone()
+        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+            self.device
+        )
+
     def store_new_quizzes(self, new_quizzes, for_train=True):
         if for_train:
             self.train_quizzes.append(new_quizzes)