X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=70478493f100584588e7389b465303054c62af1f;hb=a2346746c9b417eaf97aad87ed31dea92c3bb887;hp=809f79032a55204abd8add007b025ca54b1ad227;hpb=60d829ba77c9769009d3d5a93a50d23c532d019a;p=culture.git diff --git a/mygpt.py b/mygpt.py index 809f790..7047849 100755 --- a/mygpt.py +++ b/mygpt.py @@ -310,9 +310,8 @@ class MyGPT(nn.Module): dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - if seq_logproba is not None: - all_t = torch.arange(t_next.size(0)) - seq_logproba += logits[all_t, t_next].sum(dim=-1) + all_n = torch.arange(t_next.size(0)) + seq_logproba += logits[all_n, t_next].sum(dim=-1) input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]