+ def test_model(
+ self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
+ ):
+ nb_tokens_to_generate = self.height * self.width + 3
+ result_descr = []
+
+ for primer_descr in primers_descr:
+
+ results = autoregression(
+ model,
+ self.batch_size,
+ nb_samples=nb_per_primer,
+ nb_tokens_to_generate=nb_tokens_to_generate,
+ primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
+ device=self.device,
+ )
+
+ l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
+ result_descr += l
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+ log_string(
+ f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
+ )
+
+ np = torch.tensor(np)
+ count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height=self.height, width=self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f"result_picoclvr_{n_epoch:04d}.png"
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
+ )
+ log_string(f"wrote {image_name}")
+
+ return count
+