- for j in range(nb_tokens):
- t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
- input = torch.tensor(t, device = self.device)
- output = model(input)
- logits = output[0, -1]
- if args.synthesis_sampling:
- dist = torch.distributions.categorical.Categorical(logits = logits)
- t = dist.sample()
- else:
- t = logits.argmax()
- t_generated.append(self.id2token[t.item()])
+ l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
+ result_descr += l
+
+ np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
+
+ nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+ log_string(
+ f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
+ )
+
+ np = torch.tensor(np)
+ count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height=self.height, width=self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f"result_picoclvr_{n_epoch:04d}.png"
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
+ )
+ log_string(f"wrote {image_name}")
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ "red above green <sep> green top <sep> blue right of red <img>",
+ "there is red <sep> there is yellow <sep> there is blue <img>",
+ "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
+ "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
+ ]
+
+ self.test_model(
+ n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
+ )
+
+ # FAR TOO SLOW!!!
+
+ # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
+
+ # count=self.test_model(
+ # n_epoch, model,
+ # test_primers_descr,
+ # nb_per_primer=1, generate_images=False
+ # )