# ----------------------------------
-parser.add_argument("--nb_gpts", type=int, default=5)
-
-parser.add_argument("--min_succeed_to_validate", type=int, default=2)
-
-parser.add_argument("--max_fail_to_validate", type=int, default=3)
+parser.add_argument("--nb_gpts", type=int, default=2)
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
nb_samples_accumulated = 0
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_test_samples, model.test_c_quiz_bags
+ args.nb_test_samples, test_c_quiz_bags
)
+
src = zip(
full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
)
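
# A minimal sketch of the batching idiom above: Tensor.split cuts along
# dim 0 into chunks of at most batch_size, and zip pairs each input
# chunk with its matching loss-mask chunk. Shapes are illustrative.
import torch

full_input = torch.randn(10, 4)
full_mask_loss = torch.ones(10, 4)
batch_size = 4

for input, mask_loss in zip(
    full_input.split(batch_size), full_mask_loss.split(batch_size)
):
    print(input.size(0))  # 4, 4, 2 -- the last chunk may be smaller
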
log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
- input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags)
+ input, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
model.test_accuracy = quiz_machine.produce_results(
n_epoch=n_epoch,
nb_train_samples, acc_train_loss = 0, 0.0
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, model.train_c_quiz_bags
+ args.nb_train_samples, train_c_quiz_bags
)
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
return l.exp()
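
# A sketch of the computation that `return l.exp()` presumably
# concludes: perplexity is the exponential of the mean cross-entropy
# (in nats). The logits and targets here are random stand-ins.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 100)          # 8 positions, vocabulary of 100
targets = torch.randint(100, (8,))
l = F.cross_entropy(logits, targets)  # mean negative log-likelihood
perplexity = l.exp()                  # ~100 for near-uniform predictions
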
-def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test):
+def create_c_quizzes(
+ main_model,
+ other_models,
+ quiz_machine,
+ nb_for_train,
+ train_c_quiz_bags,
+ nb_for_test,
+ test_c_quiz_bags,
+):
nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
e = "???"
log_string(
- f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
+ f"keep c_quizzes model {main_model.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
)
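
# A sketch of the rate/ETA bookkeeping the log line above depends on;
# `e = "???"` is the placeholder before anything has been validated, and
# the exact time arithmetic is an assumption about the elided code.
import datetime, time

start_time = time.perf_counter()
nb_validated, nb_to_validate = 50, 400  # illustrative values
time.sleep(0.01)                        # stands in for the validation work
duration = time.perf_counter() - start_time

if nb_validated > 0:
    per_hour = int((nb_validated * 3600) / duration)
    seconds_left = duration * (nb_to_validate - nb_validated) / nb_validated
    eta = datetime.datetime.now() + datetime.timedelta(seconds=seconds_left)
    e = eta.strftime("%a %H:%M")
else:
    e = "???"
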
# Save some images
args.result_dir, filename, c_quizzes[:128], comments=comments
)
-
-log_string(
- f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
-)
+ log_string(
+ f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in train_c_quiz_bags ])} test {sum([q.size(0) for q in test_c_quiz_bags ])}"
+ )
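
# A sketch of the bag accounting in the log line above: a "bag" is a
# tensor of quizzes stacked along dim 0, so summing size(0) over a list
# of bags counts quizzes. The quiz length 10 is an arbitrary stand-in.
import torch

train_c_quiz_bags = [torch.zeros(128, 10), torch.zeros(64, 10)]
nb_train = sum(q.size(0) for q in train_c_quiz_bags)
assert nb_train == 192
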
######################################################################
)
model.id = k
- model.train_c_quiz_bags = []
- model.test_c_quiz_bags = []
if args.schedule_free:
model.optimizer = schedulefree.AdamWScheduleFree(
######################################################################
+train_c_quiz_bags = []
+test_c_quiz_bags = []
+
current_epoch = 0
if args.resume:
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
- model.train_c_quiz_bags = d["train_c_quiz_bags"]
- model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
except FileNotFoundError:
log_string(f"cannot find {filename}")
state = torch.load(os.path.join(args.result_dir, filename))
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
+ train_c_quiz_bags = state["train_c_quiz_bags"]
+ test_c_quiz_bags = state["test_c_quiz_bags"]
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
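
# A minimal sketch of the resume path above: the shared c_quiz bags now
# live in the single global state dict next to the epoch counter, and a
# missing file simply means a fresh start. Paths are illustrative.
import os
import torch

result_dir, filename = "/tmp", "state.pth"
state = {"current_epoch": 3, "train_c_quiz_bags": [], "test_c_quiz_bags": []}
torch.save(state, os.path.join(result_dir, filename))

try:
    state = torch.load(os.path.join(result_dir, filename))
    current_epoch = state["current_epoch"]
    train_c_quiz_bags = state["train_c_quiz_bags"]
    test_c_quiz_bags = state["test_c_quiz_bags"]
except FileNotFoundError:
    current_epoch, train_c_quiz_bags, test_c_quiz_bags = 0, [], []
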
######################################################################
if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples
if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples
log_string(
f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
for n_epoch in range(current_epoch, args.nb_epochs):
state = {
"current_epoch": n_epoch,
+ "train_c_quiz_bags": train_c_quiz_bags,
+ "test_c_quiz_bags": test_c_quiz_bags,
}
filename = "state.pth"
torch.save(state, os.path.join(args.result_dir, filename))
##################################################
if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- record_new_c_quizzes(
- models,
- quiz_machine,
- args.nb_new_c_quizzes_for_train,
- args.nb_new_c_quizzes_for_test,
+ create_c_quizzes(
+ main_model=models[0],
+ other_models=models[1:],
+ quiz_machine=quiz_machine,
+ nb_for_train=args.nb_new_c_quizzes_for_train,
+ train_c_quiz_bags=train_c_quiz_bags,
+ nb_for_test=args.nb_new_c_quizzes_for_test,
+ test_c_quiz_bags=test_c_quiz_bags,
)
for model in models:
).to(main_device)
model.load_state_dict(new_model.state_dict())
model.test_accuracy = 0.0
- model.best_test_accuracy = 0.0
- model.best_dict = copy.deepcopy(model.state_dict())
##################################################
# Select, improve, and eval the worst model(s)
-# This ugly recipe will pick the worst if there some below
-# args.accuracy_to_make_c_quizzes or one at random if they
-# are all above
+# Rank the models by increasing test accuracy
- key=lambda m: float(
- m.test_accuracy
- if m.test_accuracy < args.accuracy_to_make_c_quizzes
- else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
- ),
+ key=lambda m: float(m.test_accuracy),
)
weakest_models = ranked_models[: len(gpus)]
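
# A self-contained sketch of the simplified selection above: rank by
# increasing test accuracy and retrain the len(gpus) weakest models.
# The Model class is a stand-in for the script's GPT objects.
class Model:
    def __init__(self, id, test_accuracy):
        self.id, self.test_accuracy = id, test_accuracy

models = [Model(0, 0.97), Model(1, 0.62), Model(2, 0.81)]
gpus = ["cuda:0"]  # one training device in this toy setup

ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
assert [m.id for m in weakest_models] == [1]
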
for t in threads:
t.join()
- total_time_training_models += time.perf_counter() - start_time
-
for model in weakest_models:
save_additional_results(n_epoch, model, models, c_quizzes_procedure)
"state_dict": model.state_dict(),
"optimizer_state_dict": model.optimizer.state_dict(),
"test_accuracy": model.test_accuracy,
- "best_test_accuracy": model.best_test_accuracy,
- "best_dict": model.best_dict,
- "train_c_quiz_bags": model.train_c_quiz_bags,
- "test_c_quiz_bags": model.test_c_quiz_bags,
},
os.path.join(args.result_dir, filename),
)