many_colors = many_colors
)
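+ # Keep the first fifth of the descriptions for the test set, the remaining four fifths for training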
+ self.test_descr = descr[:nb // 5]
+ self.train_descr = descr[nb // 5:]
+
descr = [ s.strip().split(' ') for s in descr ]
l = max([ len(s) for s in descr ])
descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
def vocabulary_size(self):
return len(self.token2id)
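+ # Autoregressive generation: repeatedly tokenize the primer plus what has been generated so far, run the model, and append the predicted next token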
+ def generate(self, primer, model, nb_tokens):
+ t_primer = primer.strip().split(' ')
+ t_generated = [ ]
+
+ for j in range(nb_tokens):
+ t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
+ input = torch.tensor(t, device = self.device)
+ output = model(input)
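+ # Keep only the logits at the last position, which predict the next token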
+ logits = output[0, -1]
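+ # Either sample the next token from the softmax of the logits, or greedily pick the most likely one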
+ if args.synthesis_sampling:
+ dist = torch.distributions.categorical.Categorical(logits = logits)
+ t = dist.sample()
+ else:
+ t = logits.argmax()
+ t_generated.append(self.id2token[t.item()])
+
+ return ' '.join(t_primer + t_generated)
+
def produce_results(self, n_epoch, model, nb_tokens = 50):
- img = [ ]
+ descr = [ ]
nb_per_primer = 8
for primer in [
]:
for k in range(nb_per_primer):
- t_primer = primer.strip().split(' ')
- t_generated = [ ]
-
- for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
- input = torch.tensor(t, device = self.device)
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax()
- t_generated.append(self.id2token[t.item()])
-
- descr = [ ' '.join(t_primer + t_generated) ]
- img += [ picoclvr.descr2img(descr) ]
+ descr.append(self.generate(primer, model, nb_tokens))
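+ # Render every generated description into an image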
+ img = [ picoclvr.descr2img(d) for d in descr ]
img = torch.cat(img, 0)
file_name = f'result_picoclvr_{n_epoch:04d}.png'
torchvision.utils.save_image(img / 255.,
file_name, nrow = nb_per_primer, pad_value = 0.8)
log_string(f'wrote {file_name}')
+ log_string(f'nb_missing {picoclvr.nb_missing_properties(descr)}')
+
######################################################################
class TaskWiki103(Task):