# NOTE(review): the lines below are diff hunks (+/- prefixed) from several
# disjoint regions of a training script, not one contiguous Python block.
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
# The patch threads the result directory through to this call.
+ result_dir=args.result_dir,
logger=log_string,
device=device,
)
# exp() argument is capped at 100 to avoid float overflow on divergent losses.
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
# Typo fix in the logged key: "train)perplexity" -> "train_perplexity".
- log_string(f"train)perplexity {n_epoch} {train_perplexity}")
+ log_string(f"train_perplexity {n_epoch} {train_perplexity}")
######################################################################
nb_test_samples += input.size(0)
# The patch now captures produce_results' return value so the caller can
# gate new-quiz generation on the measured test accuracy.
- task.produce_results(
+ main_test_accuracy = task.produce_results(
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
)
test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
# Same typo fix on the test-side key: "test)perplexity" -> "test_perplexity".
- log_string(f"test)perplexity {n_epoch} {test_perplexity}")
+
+ log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
# NOTE(review): presumably the enclosing function is run_tests — confirm;
# it now returns the accuracy consumed by the epoch loop further down.
+ return main_test_accuracy
######################################################################
one_epoch(model, task, learning_rate)
- run_tests(model, task, deterministic_synthesis=True)
+ test_accuracy = run_tests(model, task, deterministic_synthesis=False)
+
+ # --------------------------------------------
+
+ if test_accuracy >= 0.8:
+ nb_for_train, nb_for_test = 1000, 100
+ kept = []
+
+ while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+ new_quizzes, nb_correct = task.create_new_quizzes(
+ n_epoch=n_epoch,
+ result_dir=args.result_dir,
+ logger=log_string,
+ nb=nb_required,
+ model=model,
+ nb_runs=10,
+ )
+
+ to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+ log_string(f"keep {to_keep.size(0)} quizzes")
+ kept.append(to_keep)
+
+ new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+ task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+ task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+ task.save_image(
+ new_quizzes[:96],
+ args.result_dir,
+ f"world_new_{n_epoch:04d}.png",
+ log_string,
+ )
# --------------------------------------------