# improve it
one_epoch(model, task)
+ task.renew_samples(args.nb_train_samples // args.nb_gpts)
+
log_string(
f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
)
# test it
run_tests(model, task, deterministic_synthesis=False)
- if model.main_test_accuracy >= accuracy_to_make_quizzes:
+ if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
other_models = models.copy()
other_models.remove(model)
return main_test_accuracy
+ def renew_samples(self, nb, for_train=True):
+ input = self.train_input if for_train else self.test_input
+ nb = min(nb, input.size(0))
+ input[:-nb] = input[nb:].clone()
+ input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+ self.device
+ )
+
def store_new_quizzes(self, new_quizzes, for_train=True):
if for_train:
self.train_quizzes.append(new_quizzes)