log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
- model.main_test_accuracy = quiz_machine.produce_results(
+ model.test_accuracy = quiz_machine.produce_results(
n_epoch=n_epoch,
model=model,
input=full_input[:2000],
else:
model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
- model.main_test_accuracy = 0.0
+ model.test_accuracy = 0.0
models.append(model)
d = torch.load(os.path.join(args.result_dir, filename))
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
- model.main_test_accuracy = d["main_test_accuracy"]
+ model.test_accuracy = d["test_accuracy"]
+ model.best_test_accuracy = d["best_test_accuracy"]
+ model.best_dict = d["best_dict"]
model.train_c_quiz_bags = d["train_c_quiz_bags"]
model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
log_string(f"--- epoch {n_epoch} ----------------------------------------")
- cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
##################################################
# If all the models are good enough, generate new quizzes and
# re-compute the test errors
- if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+ for model in models:
+ if model.test_accuracy >= model.best_test_accuracy:
+ model.best_dict = copy.deepcopy(model.state_dict())
+ model.best_test_accuracy = model.test_accuracy
+
+ if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+ for model in models:
+ model.current_dict = copy.deepcopy(model.state_dict())
+ model.load_state_dict(model.best_dict)
+
record_new_c_quizzes(
models,
quiz_machine,
# Force one epoch of training
for model in models:
- model.main_test_accuracy = 0.0
+ model.load_state_dict(model.current_dict)
##################################################
# Select, improve, and eval the worst model(s)
- ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
{
"state_dict": model.state_dict(),
"optimizer_state_dict": model.optimizer.state_dict(),
- "main_test_accuracy": model.main_test_accuracy,
+ "test_accuracy": model.test_accuracy,
+ "best_test_accuracy": model.best_test_accuracy,
+ "best_dict": model.best_dict,
"train_c_quiz_bags": model.train_c_quiz_bags,
"test_c_quiz_bags": model.test_c_quiz_bags,
},