X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=809f79032a55204abd8add007b025ca54b1ad227;hb=60d829ba77c9769009d3d5a93a50d23c532d019a;hp=ab4ccbc1b2168c27419e5810d07bd55d2c48665b;hpb=35a16ac34a3f1af05323a9cb3823fbcfd74035a4;p=culture.git diff --git a/mygpt.py b/mygpt.py index ab4ccbc..809f790 100755 --- a/mygpt.py +++ b/mygpt.py @@ -279,7 +279,7 @@ class MyGPT(nn.Module): self, input, ar_mask, - summed_logits, + seq_logproba, temperature=1.0, deterministic_synthesis=False, forbidden_tokens=None, @@ -309,10 +309,10 @@ class MyGPT(nn.Module): else: dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - if summed_logits is not None: - summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum( - dim=-1 - ) + + if seq_logproba is not None: + all_t = torch.arange(t_next.size(0)) + seq_logproba += logits[all_t, t_next].sum(dim=-1) input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]