"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
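+    # defaults for the new "world" task (same 37M model and batch size as "byheart")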
+ "world": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 50000,
+ "nb_test_samples": 10000,
+ },
"byheart": {
"model": "37M",
"batch_size": 25,
)
args.max_percents_of_test_in_train = -1
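+# the "world" task produces the synthetic quizzes harvested in the main loop below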
+elif args.task == "world":
+ task = tasks.World(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.physical_batch_size,
+ result_dir=args.result_dir,
+ logger=log_string,
+ device=device,
+ )
+ args.max_percents_of_test_in_train = -1
+
elif args.task == "learnop":
task = tasks.SandBox(
problem=problems.ProblemLearnOperator(),
time_pred_result = None
-for n_epoch in range(nb_epochs_finished, args.nb_epochs):
- learning_rate = learning_rate_schedule[n_epoch]
+######################################################################
+
+def one_epoch(model, task, learning_rate):
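+    # one training pass over the train split; uses the module-level args,
+    # optimizer and epoch index n_epoch, and accumulates acc_train_loss and
+    # nb_train_samples over the training batches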
log_string(f"learning_rate {learning_rate}")
if args.optim == "sgd":
if nb_train_samples % args.batch_size == 0:
optimizer.step()
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+
+
+######################################################################
+
+
+def run_tests(model, task, deterministic_synthesis):
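+    # evaluation pass, no gradients; assumes the task exposes batches(split=...)
+    # and that produce_results() returns the main test accuracy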
    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)

            bs = model(mygpt.BracketedSequence(input))
-            output_ar = bs.x
+            output = bs.x

            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
- log_string(
- f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
- )
-
- task.produce_results(
+ main_test_accuracy = task.produce_results(
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
logger=log_string,
- deterministic_synthesis=args.deterministic_synthesis,
+ deterministic_synthesis=deterministic_synthesis,
)
-    time_current_result = datetime.datetime.now()
-    if time_pred_result is not None:
-        log_string(
-            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-        )
-    time_pred_result = time_current_result
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+ log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+ return main_test_accuracy
+
+
+######################################################################
+
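+# main loop: one epoch of training followed by an evaluation pass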
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+ learning_rate = learning_rate_schedule[n_epoch]
+
+ one_epoch(model, task, learning_rate)
+
+ test_accuracy = run_tests(model, task, deterministic_synthesis=False)
+
+ # --------------------------------------------
+
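+    # harvest new quizzes once the model reaches 80% test accuracy: keep only
+    # those it answers correctly in 8 or 9 runs out of 10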
+ if test_accuracy >= 0.8:
+ nb_for_train, nb_for_test = 1000, 100
+ kept = []
+
+        while sum(x.size(0) for x in kept) < nb_for_train + nb_for_test:
+            # generate only as many quizzes as are still missing
+            nb_required = nb_for_train + nb_for_test - sum(x.size(0) for x in kept)
+ new_quizzes, nb_correct = task.create_new_quizzes(
+ n_epoch=n_epoch,
+ result_dir=args.result_dir,
+ logger=log_string,
+ nb=nb_required,
+ model=model,
+ nb_runs=10,
+            )
+
+ to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)]
+ log_string(f"keep {to_keep.size(0)} quizzes")
+ kept.append(to_keep)
+
+ new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+
+ task.store_new_quizzes(new_quizzes[:nb_for_train], train=True)
+ task.store_new_quizzes(new_quizzes[nb_for_train:], train=False)
+
+ task.save_image(
+ new_quizzes[:96],
+ args.result_dir,
+ f"world_new_{n_epoch:04d}.png",
+ log_string,
+ )
+
+ # --------------------------------------------
+
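+    # estimate when the next result will be ready by extrapolating the
+    # duration of the previous iteration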
+ time_current_result = datetime.datetime.now()
+ if time_pred_result is not None:
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
+ time_pred_result = time_current_result
+
+ # --------------------------------------------
checkpoint = {
"nb_epochs_finished": n_epoch + 1,