state = torch.load(os.path.join(args.result_dir, filename))
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
- total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
- total_time_training_models = state["total_time_training_models"]
+ # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+ # total_time_training_models = state["total_time_training_models"]
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
if (
min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
- and total_time_training_models > total_time_generating_c_quizzes
+ and total_time_training_models >= total_time_generating_c_quizzes
):
for model in models:
model.current_dict = copy.deepcopy(model.state_dict())
##################################################
# Select, improve, and eval the worst model(s)
- if total_time_training_models <= total_time_generating_c_quizzes:
+ if total_time_training_models < total_time_generating_c_quizzes:
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
threads = []
start_time = time.perf_counter()
+
for gpu, model in zip(gpus, weakest_models):
log_string(f"training model {model.id}")