From 23c14b8c000e9286dd76dda9e86488501a133d7f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 11 Sep 2024 11:25:52 +0200 Subject: [PATCH] Update. --- main.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 9e1726a..b83cabd 100755 --- a/main.py +++ b/main.py @@ -886,11 +886,7 @@ def run_ae_test( c_quizzes=c_quizzes, desc="test", ): - result = ae_generate( - model, - (1 - mask_generate) * x_0, - mask_generate, - ) + result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate) correct = (result == x_0).min(dim=1).values.long() predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[ :, :, 1 @@ -1057,6 +1053,26 @@ def quiz_validation( nb_hints=0, nb_runs=1, ): + if c_quizzes.size(0) > args.inference_batch_size: + record = [] + for q in c_quizzes.split(args.inference_batch_size): + record.append( + quiz_validation( + models, + q, + local_device, + nb_have_to_be_correct, + nb_have_to_be_wrong, + nb_mistakes_to_be_wrong, + nb_hints=0, + nb_runs=1, + ) + ) + + return (torch.cat([tk for tk, _ in record], dim=0)), ( + torch.cat([w for _, w in record], dim=0) + ) + record_wrong = [] nb_correct, nb_wrong = 0, 0 @@ -1086,7 +1102,7 @@ def quiz_validation( result = ae_generate( model=model, - x_0=(1 - mask_generate) * c_quizzes, + x_0=c_quizzes, mask_generate=mask_generate, mask_hints=mask_hints, ) -- 2.39.5