return targets, logits
-def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
+def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device):
with torch.autograd.no_grad():
model.eval().to(local_device)
args.nb_test_samples,
data_structures,
local_device,
- "test",
+ c_quizzes=c_quizzes,
+ desc="test",
):
targets, logits = targets_and_prediction(model, input, mask_generate)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
args.nb_test_samples,
data_structures,
local_device,
+ c_quizzes,
"test",
):
targets = input.clone()
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
)
- run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+ run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device)
######################################################################