c_quizzes=c_quizzes,
desc="test",
):
- result = ae_generate(
- model,
- (1 - mask_generate) * x_0,
- mask_generate,
- )
+ result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate)
correct = (result == x_0).min(dim=1).values.long()
predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
:, :, 1
nb_hints=0,
nb_runs=1,
):
+ if c_quizzes.size(0) > args.inference_batch_size:
+ record = []
+ for q in c_quizzes.split(args.inference_batch_size):
+ record.append(
+ quiz_validation(
+ models,
+ q,
+ local_device,
+ nb_have_to_be_correct,
+ nb_have_to_be_wrong,
+ nb_mistakes_to_be_wrong,
+ nb_hints=nb_hints,
+ nb_runs=nb_runs,
+ )
+ )
+
+ return (torch.cat([tk for tk, _ in record], dim=0)), (
+ torch.cat([w for _, w in record], dim=0)
+ )
+
record_wrong = []
nb_correct, nb_wrong = 0, 0
result = ae_generate(
model=model,
- x_0=(1 - mask_generate) * c_quizzes,
+ x_0=c_quizzes,
mask_generate=mask_generate,
mask_hints=mask_hints,
)