X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=mygpt.py;h=a17849181edafb878c6c55cc45ad0422b01c1dce;hb=8a548630c88957264306db4354e880414b0fa8ef;hp=3bb3519ac1d70bf38ccb8fe6c88c106c8b1d0e2b;hpb=17c63771f2ca82ce39d8406e377ace2015fe69fc;p=culture.git diff --git a/mygpt.py b/mygpt.py index 3bb3519..a178491 100755 --- a/mygpt.py +++ b/mygpt.py @@ -292,7 +292,7 @@ class MyGPT(nn.Module): ) # Needed to initialize the model's cache for s in range(to_generate.min(), to_generate.max() + 1): output = self(BracketedSequence(input, s, 1)).x - logits = output[:, s] + logits = output[:, s] / temperature if forbidden_tokens is not None: logits = logits.masked_fill(forbidden_tokens, float("-inf")) if forced_biases is not None: @@ -304,7 +304,7 @@ class MyGPT(nn.Module): t_next = dist.sample() sum_logits += logits.log_softmax(dim=-1)[ torch.arange(t_next.size(0)), t_next - ] + ].sum() input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] return sum_logits