- def produce_results(self, n_epoch, model, nb_tokens = 50):
- img = [ ]
+ def generate(self, primer, model, nb_tokens):
+ t_primer = primer.strip().split(' ')
+ t_generated = [ ]
+
+ for j in range(nb_tokens):
+ t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
+ input = torch.tensor(t, device = self.device)
+ output = model(input)
+ logits = output[0, -1]
+ if args.synthesis_sampling:
+ dist = torch.distributions.categorical.Categorical(logits = logits)
+ t = dist.sample()
+ else:
+ t = logits.argmax()
+ t_generated.append(self.id2token[t.item()])
+
+ return ' '.join(t_primer + t_generated)
+
+ def produce_results(self, n_epoch, model, nb_tokens = None):
+ if nb_tokens is None:
+ nb_tokens = self.height * self.width + 3
+ descr = [ ]