From c0e6db6a39da18c186d9c8c5abc106f442da0f95 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 12 Aug 2024 16:56:58 +0200 Subject: [PATCH] Update. --- main.py | 17 ++++++++++------- quiz_machine.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index fbebbb9..e516a77 100755 --- a/main.py +++ b/main.py @@ -667,13 +667,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ) #!!!!!!!!!!!!!!!!!!!! - l = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) - for s in range(seq_logproba.size(0)): - print(f"-- {s=} ----------------") - for m in range(seq_logproba.size(1)): - print("DEBUG", seq_logproba[s, m].item(), l[s, m].item()) + for m in range(seq_logproba.size(1)): + l = quiz_machine.models_logprobas( + [models[m]], + solved_c_quizzes[:, m, :], + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + (0, 0, 0, 0), + ) + for s in range(seq_logproba.size(0)): + print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item()) exit(0) #!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/quiz_machine.py b/quiz_machine.py index 6aa4e9b..b2287b8 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression( all_n = torch.arange(t_next.size(0)) - acc_seq_logproba += ar_mask[:, s] * logits[all_n, t_next] + acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] -- 2.39.5