- file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
- with open(file_name, 'w') as outfile:
- for primer in [
- 'the cat is hunting a',
- 'paris is the capital',
- 'cars are convenient',
- 'the difference between men and women is',
- 'the object was blue all over and green all over it was',
- 'cherries are red and lemons are',
- 'cherries are sweet and lemons are',
- 'two plus three equals',
- 'deep learning is',
- ]:
- t_primer = self.tokenizer(primer)
- t_generated = [ ]
-
- for j in range(nb_tokens):
-
- input = self.tensorize([ t_primer + t_generated ]).to(self.device)
- input = F.pad(input, (0, 1)) # Add the next token, the one to predict
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t_next = dist.sample()
- else:
- t_next = logits.argmax()
- t_generated.append(self.vocab.lookup_token(t_next))
- if t_generated[-1] == '<nul>': break
-
- s = ' '.join(t_generated)
-
- outfile.write(f'<{primer}> {s}\n')
-
- log_string(f'wrote {file_name}')
+ file_name = f"result_wiki103_{n_epoch:04d}.txt"
+
+ with open(file_name, "w") as outfile:
+ for primer in [
+ "the cat is hunting a",
+ "paris is the capital",
+ "cars are convenient",
+ "the difference between men and women is",
+ "the object was blue all over and green all over it was",
+ "cherries are red and lemons are",
+ "cherries are sweet and lemons are",
+ "two plus three equals",
+ "deep learning is",
+ ]:
+ t_primer = self.tokenizer(primer)
+ t_generated = []
+
+ for j in range(nb_tokens):
+
+ input = self.tensorize([t_primer + t_generated]).to(self.device)
+ input = F.pad(
+ input, (0, 1)
+ ) # Add the next token, the one to predict
+ output = model(input)
+ logits = output[0, -1]
+ if args.deterministic_synthesis:
+ t_next = logits.argmax()
+ else:
+ dist = torch.distributions.categorical.Categorical(
+ logits=logits
+ )
+ t_next = dist.sample()
+ t_generated.append(self.vocab.lookup_token(t_next))
+ if t_generated[-1] == "<nul>":
+ break
+
+ s = " ".join(t_generated)
+
+ outfile.write(f"<{primer}> {s}\n")
+
+ log_string(f"wrote {file_name}")
+