X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=b6eb6fefa15658ed36c89096a36a69c78db2902a;hb=3ae0c8f3767e4285ab548e4548576a6ddf6003bb;hp=ee44ebe9ed4e1416def886b44e333d15947ebd8d;hpb=0a8ed78035264cd7552b712596e897d6e73b7ef4;p=mygpt.git diff --git a/main.py b/main.py index ee44ebe..b6eb6fe 100755 --- a/main.py +++ b/main.py @@ -349,11 +349,11 @@ class TaskWiki103(Task): input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax() + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax() t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break