From 69742f27abf892f27452e3b2e31f65af236dfa2e Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 12 Aug 2024 16:44:46 +0200
Subject: [PATCH] Update.

---
 main.py         | 12 ++++++------
 quiz_machine.py |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index f9dc35d..b23a52b 100755
--- a/main.py
+++ b/main.py
@@ -503,14 +503,14 @@ def model_transformer_cold(model):
 
 
 c_quizzes_procedure = [
-    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
     # Generate the full thing at high temp
-    (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot),
+    # (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot),
     # Fix A and B
-    (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold),
+    # (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold),
     # Fix f_B
     # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
     # Fix f_A
diff --git a/quiz_machine.py b/quiz_machine.py
index 92da03d..5bab1e5 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression(
 
         all_n = torch.arange(t_next.size(0))
 
-        seq_logproba += logits[all_n, t_next]
+        seq_logproba += logits.log_softmax(dim=1)[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
-- 
2.39.5
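
Note on the quiz_machine.py hunk: the old line added raw logits to seq_logproba, which accumulates unnormalized scores rather than log-probabilities; applying log_softmax over the vocabulary dimension first makes each added term the log-probability of the token actually sampled. A minimal standalone sketch of the difference follows; the batch size of 4 and vocabulary of 10 are illustrative only, not values from the repository.

    import torch

    # Hypothetical shapes: 4 sequences, vocabulary of 10 tokens.
    logits = torch.randn(4, 10)            # unnormalized scores at the current step
    t_next = torch.randint(0, 10, (4,))    # token sampled for each sequence
    all_n = torch.arange(t_next.size(0))

    seq_logproba = torch.zeros(4)

    # Before the patch: adds raw scores, which are not log-probabilities.
    # seq_logproba += logits[all_n, t_next]

    # After the patch: normalize over the vocabulary first, so each added term
    # is log p(t_next | prefix) and seq_logproba sums to a true sequence
    # log-probability over the autoregressive steps.
    seq_logproba += logits.log_softmax(dim=1)[all_n, t_next]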