From 197703b2e1d388a380634522dd2c52182d585028 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 22 Jul 2024 06:38:40 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quiz_machine.py b/quiz_machine.py index 5f14528..c73b6d0 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -73,7 +73,8 @@ def one_batch_masked_inplace_autoregression( logits = output[:, s] - logits = logit_transformer(s, logits).log_softmax(dim=-1) + if logit_transformer is not None: + logits = logit_transformer(s, logits).log_softmax(dim=-1) if deterministic_synthesis: t_next = logits.argmax(-1) -- 2.20.1