- for primer in [
- 'red above green <sep> green top <sep> blue right of red <img>',
- 'there is red <sep> there is yellow <sep> there is blue <img>',
- 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
- 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
- ]:
+ np = torch.tensor(np)
+ count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
+
+ if generate_images:
+ img = [
+ picoclvr.descr2img(d, height=self.height, width=self.width)
+ for d in result_descr
+ ]
+
+ img = torch.cat(img, 0)
+ image_name = f"result_picoclvr_{n_epoch:04d}.png"
+ torchvision.utils.save_image(
+ img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
+ )
+ log_string(f"wrote {image_name}")
+
+ return count
+
+ def produce_results(self, n_epoch, model):
+ primers_descr = [
+ "red above green <sep> green top <sep> blue right of red <img>",
+ "there is red <sep> there is yellow <sep> there is blue <img>",
+ "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
+ "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
+ ]
+
+ self.test_model(
+ n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
+ )