nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
+ result_dir=args.result_dir,
logger=log_string,
device=device,
)
nb_test_samples += input.size(0)
- task.produce_results(
+ main_test_accuracy = task.produce_results(
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
)
test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+ return main_test_accuracy
+
######################################################################
one_epoch(model, task, learning_rate)
- run_tests(model, task, deterministic_synthesis=False)
+ test_accuracy = run_tests(model, task, deterministic_synthesis=False)
# --------------------------------------------
- if n_epoch >= 3:
- nb_required = 1000
+ if test_accuracy >= 0.8:
+ nb_for_train, nb_for_test = 1000, 100
kept = []
- while sum([x.size(0) for x in kept]) < nb_required:
- new_problems, nb_correct = task.create_new_problems(
+ while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
+ new_quizzes, nb_correct = task.create_new_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
nb_runs=10,
)
- to_keep = new_problems[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
- log_string(f"keep {to_keep.size(0)} problems")
+ to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+ log_string(f"keep {to_keep.size(0)} quizzes")
kept.append(to_keep)
- new_problems = torch.cat(kept, dim=0)[:nb_required]
+ new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+ task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+ task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+ task.save_image(
+ new_quizzes[:96],
+ args.result_dir,
+ f"world_new_{n_epoch:04d}.png",
+ log_string,
+ )
# --------------------------------------------