return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
+def ae_predict(model, imt_set, local_device=main_device):
model.eval().to(local_device)
record = []
- src = imt_set.split(args.eval_batch_size)
-
- if desc is not None:
- src = tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc=desc,
- total=imt_set.size(0) // args.eval_batch_size,
- )
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,
+ )
for imt in src:
# some paranoia
return torch.cat(record)
-def predict_full(
+def predict_the_four_grids(
model, input, with_noise=False, with_hints=False, local_device=main_device
):
input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
dynamic_ncols=True,
desc=label,
total=quizzes.size(0) // batch_size,
+ delay=10,
):
input, masks, targets = imt.unbind(dim=1)
if train and nb_samples % args.batch_size == 0:
for model in models:
model = copy.deepcopy(model).to(local_device).eval()
- result = predict_full(
+ predicted = predict_the_four_grids(
model=model,
input=quizzes,
with_noise=False,
with_hints=True,
local_device=local_device,
)
-
- nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result)
+ nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
nb_correct += (nb_mistakes == 0).long()
-
- # result = predict_full(
- # model=model,
- # input=quizzes,
- # with_noise=False,
- # with_hints=False,
- # local_device=local_device,
- # )
-
nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
######################################################################
-c_quizzes = None
+main_c_quizzes = None
######################################################################
state = {
"current_epoch": n_epoch,
- "c_quizzes": c_quizzes,
+ "main_c_quizzes": main_c_quizzes,
}
filename = "state.pth"
lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
- if c_quizzes is None:
+ if main_c_quizzes is None:
save_models(models, "naive")
nb_gpus = len(gpus)
log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
- c_quizzes = (
+ main_c_quizzes = (
new_c_quizzes
- if c_quizzes is None
- else torch.cat([c_quizzes, new_c_quizzes])
+ if main_c_quizzes is None
+ else torch.cat([main_c_quizzes, new_c_quizzes])
)
- c_quizzes = c_quizzes[-args.nb_train_samples :]
+ main_c_quizzes = main_c_quizzes[-args.nb_train_samples :]
for model in models:
model.test_accuracy = 0
- if c_quizzes is None:
+ if main_c_quizzes is None:
log_string("no_c_quiz")
else:
- log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
+ log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}")
# --------------------------------------------------------------------
multithread_execution(
one_complete_epoch,
- [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
+ [
+ (model, n_epoch, main_c_quizzes, gpu)
+ for model, gpu in zip(weakest_models, gpus)
+ ],
)
save_models(models)