######################################################################
current_epoch = 0
+total_time_generating_c_quizzes = 0
+total_time_training_models = 0
if args.resume:
for model in models:
state = torch.load(os.path.join(args.result_dir, filename))
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
+ total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+ total_time_training_models = state["total_time_training_models"]
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
args.nb_new_c_quizzes_for_test = 10
######################################################################
-######################################################################
class Folder(nn.Module):
######################################################################
for n_epoch in range(current_epoch, args.nb_epochs):
- state = {"current_epoch": n_epoch}
+ state = {
+ "current_epoch": n_epoch,
+ "total_time_training_models": total_time_training_models,
+ "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
+ }
filename = "state.pth"
torch.save(state, os.path.join(args.result_dir, filename))
log_string(f"wrote {filename}")
model.best_dict = copy.deepcopy(model.state_dict())
model.best_test_accuracy = model.test_accuracy
- if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+ # Restart: while no time has been spent generating c-quizzes, reset the accumulated training time to zero
+ if total_time_generating_c_quizzes == 0:
+ total_time_training_models = 0
+
+ if (
+ min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
+ and total_time_training_models > total_time_generating_c_quizzes
+ ):
for model in models:
model.current_dict = copy.deepcopy(model.state_dict())
model.load_state_dict(model.best_dict)
+ start_time = time.perf_counter()
record_new_c_quizzes(
models,
quiz_machine,
args.nb_new_c_quizzes_for_train,
args.nb_new_c_quizzes_for_test,
)
+ total_time_generating_c_quizzes += time.perf_counter() - start_time
# Force one epoch of training
for model in models:
##################################################
# Select, improve, and eval the worst model(s)
- ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ if total_time_training_models <= total_time_generating_c_quizzes:
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
- weakest_models = ranked_models[: len(gpus)]
+ weakest_models = ranked_models[: len(gpus)]
- threads = []
+ threads = []
- for gpu, model in zip(gpus, weakest_models):
- log_string(f"training model {model.id}")
+ start_time = time.perf_counter()
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id}")
- t = threading.Thread(
- target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
- )
+ t = threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ )
+
+ threads.append(t)
- threads.append(t)
+ t.start()
- t.start()
+ for t in threads:
+ t.join()
- for t in threads:
- t.join()
+ total_time_training_models += time.perf_counter() - start_time
# Save the models to disk
- for model in weakest_models:
+ for model in models:
filename = f"gpt_{model.id:03d}.pth"
torch.save(
{