- def produce_results(self, n_epoch, model, nb_tokens = 50):
- img = [ ]
- nb_per_primer = 8
-
- for primer in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
-
- for k in range(nb_per_primer):
- t_primer = primer.strip().split(' ')
- t_generated = [ ]
-
- for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
- input = torch.tensor(t, device = self.device)
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax()
- t_generated.append(self.id2token[t.item()])
-
- descr = [ ' '.join(t_primer + t_generated) ]
- img += [ picoclvr.descr2img(descr) ]
-
- img = torch.cat(img, 0)
- file_name = f'result_picoclvr_{n_epoch:04d}.png'
- torchvision.utils.save_image(img / 255.,
- file_name, nrow = nb_per_primer, pad_value = 0.8)
- log_string(f'wrote {file_name}')
+ def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = [ ]
+
+ for primer_descr in primers_descr:
+
+ results = autoregression(
+ model,
+ self.batch_size,
+ nb_samples = nb_per_primer,
+ nb_tokens_to_generate = nb_tokens_to_generate,
+ primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+ device = self.device
+ )
+
+ l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+ result_descr += l
+
+ np = picoclvr.nb_properties(
+ result_descr,
+ height = self.height, width = self.width
+ )
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+ log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+
+ np=torch.tensor(np)
+ count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height = self.height, width = self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f'result_picoclvr_{n_epoch:04d}.png'
+ torchvision.utils.save_image(
+ img / 255.,
+ image_name, nrow = nb_per_primer, pad_value = 0.8
+ )
+ log_string(f'wrote {image_name}')
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ 'red above green <sep> green top <sep> blue right of red <img>',
+ 'there is red <sep> there is yellow <sep> there is blue <img>',
+ 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
+ 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
+ ]
+
+ self.test_model(
+ n_epoch, model,
+ primers_descr,
+ nb_per_primer=8, generate_images=True
+ )
+
+ # FAR TOO SLOW!!!
+
+ # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+ # count=self.test_model(
+ # n_epoch, model,
+ # test_primers_descr,
+ # nb_per_primer=1, generate_images=False
+ # )
+
+ # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+ # for i in range(count.size(0)):
+ # for j in range(count.size(1)):
+ # f.write(f'{count[i,j]}')
+ # f.write(" " if j<count.size(1)-1 else "\n")