X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=mygpt.py;h=70478493f100584588e7389b465303054c62af1f;hb=8e37a868ac7dfc1cb5e924790929c6eebabbeb94;hp=ab4ccbc1b2168c27419e5810d07bd55d2c48665b;hpb=35a16ac34a3f1af05323a9cb3823fbcfd74035a4;p=culture.git diff --git a/mygpt.py b/mygpt.py index ab4ccbc..7047849 100755 --- a/mygpt.py +++ b/mygpt.py @@ -279,7 +279,7 @@ class MyGPT(nn.Module): self, input, ar_mask, - summed_logits, + seq_logproba, temperature=1.0, deterministic_synthesis=False, forbidden_tokens=None, @@ -309,10 +309,9 @@ class MyGPT(nn.Module): else: dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - if summed_logits is not None: - summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum( - dim=-1 - ) + + all_n = torch.arange(t_next.size(0)) + seq_logproba += logits[all_n, t_next].sum(dim=-1) input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]