+ test_accuracy = run_tests(model, task, deterministic_synthesis=False)
+
+ # --------------------------------------------
+
+ if test_accuracy >= 0.8:
+ nb_for_train, nb_for_test = 1000, 100
+ kept = []
+
+ while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+ new_quizzes, nb_correct = task.create_new_quizzes(
+ n_epoch=n_epoch,
+ result_dir=args.result_dir,
+ logger=log_string,
+ nb=nb_required,
+ model=model,
+ nb_runs=10,
+ )
+
+ to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+ log_string(f"keep {to_keep.size(0)} quizzes")
+ kept.append(to_keep)
+
+ new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+ task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+ task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+ task.save_image(
+ new_quizzes[:96],
+ args.result_dir,
+ f"world_new_{n_epoch:04d}.png",
+ log_string,
+ )