From 674eb2f0d02b362fbfcf8ed403b2caa329054d0a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 23 Jun 2024 23:07:05 +0200 Subject: [PATCH] Update. --- mygpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mygpt.py b/mygpt.py index 3bb3519..7117e76 100755 --- a/mygpt.py +++ b/mygpt.py @@ -304,7 +304,7 @@ class MyGPT(nn.Module): t_next = dist.sample() sum_logits += logits.log_softmax(dim=-1)[ torch.arange(t_next.size(0)), t_next - ] + ].sum() input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] return sum_logits -- 2.39.5