if total_time_generating_c_quizzes == 0:
total_time_training_models = 0
- if (
- min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
- and total_time_training_models >= total_time_generating_c_quizzes
- ):
- ######################################################################
- # Re-initalize if there are enough culture quizzes
-
+ if min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
if args.reboot:
- nb_c_quizzes_per_model = [
- sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models
- ]
+ for model in models:
+ model.current_dict = copy.deepcopy(model.state_dict())
+ model.load_state_dict(model.gen_state_dict)
+
+ while True:
+ record_new_c_quizzes(
+ models,
+ quiz_machine,
+ args.nb_new_c_quizzes_for_train,
+ args.nb_new_c_quizzes_for_test,
+ )
- p = tuple(
- f"{(x*100)/args.nb_train_samples:.02f}%" for x in nb_c_quizzes_per_model
- )
+ nb_c_quizzes_per_model = [
+ sum([x.size(0) for x in model.train_c_quiz_bags])
+ for model in models
+ ]
- log_string(f"nb_c_quizzes_per_model {p}")
+ p = tuple(
+ f"{(x*100)/args.nb_train_samples:.02f}%"
+ for x in nb_c_quizzes_per_model
+ )
- m = max(nb_c_quizzes_per_model)
+ log_string(f"nb_c_quizzes_per_model {p}")
- if m >= args.nb_train_samples:
- model = models[nb_c_quizzes_per_model.index(m)]
- common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
- nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
- log_string(
- f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
- )
+ m = max(nb_c_quizzes_per_model)
- models = create_models()
- total_time_generating_c_quizzes = 0
- total_time_training_models = 0
+ if m >= args.nb_train_samples:
+ break
- for model in models:
- model.current_dict = copy.deepcopy(model.state_dict())
- model.load_state_dict(model.gen_state_dict)
+ model = models[nb_c_quizzes_per_model.index(m)]
+ common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
+ nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
+ log_string(
+ f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
+ )
- start_time = time.perf_counter()
+ models = create_models()
+ total_time_generating_c_quizzes = 0
+ total_time_training_models = 0
- record_new_c_quizzes(
- models,
- quiz_machine,
- args.nb_new_c_quizzes_for_train,
- args.nb_new_c_quizzes_for_test,
- )
+ elif total_time_training_models >= total_time_generating_c_quizzes:
+ for model in models:
+ model.current_dict = copy.deepcopy(model.state_dict())
+ model.load_state_dict(model.gen_state_dict)
- total_time_generating_c_quizzes += time.perf_counter() - start_time
+ start_time = time.perf_counter()
- for model in models:
- model.load_state_dict(model.current_dict)
+ record_new_c_quizzes(
+ models,
+ quiz_machine,
+ args.nb_new_c_quizzes_for_train,
+ args.nb_new_c_quizzes_for_test,
+ )
+
+ total_time_generating_c_quizzes += time.perf_counter() - start_time
+
+ for model in models:
+ model.load_state_dict(model.current_dict)
##################################################
# Select, improve, and eval the worst model(s)