models.append(model)
+######################################################################
+
+# Resume support for the autoencoder models: reload each model's
+# checkpoint and the global training state written at every epoch.
+
+current_epoch = 0
+
+if args.resume:
+    for model in models:
+        filename = f"ae_{model.id:03d}.pth"
+
+        try:
+            d = torch.load(os.path.join(args.result_dir, filename))
+            model.load_state_dict(d["state_dict"])
+            model.optimizer.load_state_dict(d["optimizer_state_dict"])
+            model.test_accuracy = d["test_accuracy"]
+            # model.gen_test_accuracy = d["gen_test_accuracy"]
+            # model.gen_state_dict = d["gen_state_dict"]
+            # model.train_c_quiz_bags = d["train_c_quiz_bags"]
+            # model.test_c_quiz_bags = d["test_c_quiz_bags"]
+            log_string(f"successfully loaded {filename}")
+        except FileNotFoundError:
+            # A missing checkpoint is not fatal: that model simply
+            # starts from its freshly initialized weights.
+            log_string(f"cannot find {filename}")
+
+    try:
+        filename = "state.pth"
+        state = torch.load(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+        current_epoch = state["current_epoch"]
+        # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        # total_time_training_models = state["total_time_training_models"]
+        # common_c_quiz_bags = state["common_c_quiz_bags"]
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+
+######################################################################
+
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+
+######################################################################
+
for n_epoch in range(current_epoch, args.nb_epochs):
+    # Snapshot the global training state at the top of every epoch so
+    # that an interrupted run restarted with --resume re-executes this
+    # epoch instead of starting over from epoch 0.
+    state = {
+        "current_epoch": n_epoch,
+        # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
+        # "total_time_training_models": total_time_training_models,
+        # "common_c_quiz_bags": common_c_quiz_bags,
+    }
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
+
+    log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
+    # --------------------------------------------------------------------
+
    # NOTE(review): the code that spawns the training threads on the
    # weakest models is not visible in this hunk — `threads` and `gpus`
    # are assumed to be defined elsewhere in the file; confirm.
    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
    weakest_models = ranked_models[: len(gpus)]
    for t in threads:
        t.join()
+    # --------------------------------------------------------------------
+
+    # Checkpoint every model at the end of the epoch, mirroring the
+    # fields that are read back in the resume section.
+    for model in models:
+        filename = f"ae_{model.id:03d}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+                # "gen_test_accuracy": model.gen_test_accuracy,
+                # "gen_state_dict": model.gen_state_dict,
+                # "train_c_quiz_bags": model.train_c_quiz_bags,
+                # "test_c_quiz_bags": model.test_c_quiz_bags,
+            },
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
+
######################################################################
######################################################################
-current_epoch = 0
-
# We balance the computing time between training the models and
# generating c_quizzes
total_time_generating_c_quizzes = 0
total_time_training_models = 0
+# NOTE(review): `current_epoch` is also initialized (and loaded from
+# state.pth) in the AE resume section earlier in the file; this second
+# reset to 0 precedes the gpt resume block below — confirm the two
+# sections are not both active in the same run.
+current_epoch = 0
+
if args.resume:
for model in models:
filename = f"gpt_{model.id:03d}.pth"