- def generate(self, descr_primer, model, nb_tokens):
- results = autoregression(
- model, self.batch_size,
- 1, nb_tokens, primer = descr2tensor(descr_primer),
- device = self.device
- )
- return ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
-
- def produce_results(self, n_epoch, model, nb_tokens = None):
- if nb_tokens is None:
- nb_tokens = self.height * self.width + 3
- result_descr = [ ]
- nb_per_primer = 8
-
- for descr_primer in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
-
- for k in range(nb_per_primer):
- result_descr.append(self.generate(descr_primer, model, nb_tokens))
-
- img = [ picoclvr.descr2img(d, height = self.height, width = self.width)
- for d in result_descr ]
- img = torch.cat(img, 0)
- image_name = f'result_picoclvr_{n_epoch:04d}.png'
- torchvision.utils.save_image(
- img / 255.,
- image_name, nrow = nb_per_primer, pad_value = 0.8
+ def test_model(
+ self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
+ ):
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = []
+
+ for primer_descr in primers_descr:
+
+ results = autoregression(
+ model,
+ self.batch_size,
+ nb_samples=nb_per_primer,
+ nb_tokens_to_generate=nb_tokens_to_generate,
+ primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
+ device=self.device,
+ )
+
+ l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
+ result_descr += l
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+ log_string(
+ f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"