Update.
[culture.git] / tasks.py
index 7894fcd..b4e6f67 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -500,38 +500,26 @@ class World(Task):
 
         logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
 
-        if save_attention_image is not None:
-            for k in range(10):
-                ns = torch.randint(self.test_input.size(0), (1,)).item()
-                input = self.test_input[ns : ns + 1].clone()
+        ##############################
 
-                with torch.autograd.no_grad():
-                    t = model.training
-                    model.eval()
-                    # model.record_attention(True)
-                    model(BracketedSequence(input))
-                    model.train(t)
-                    # ram = model.retrieve_attention()
-                    # model.record_attention(False)
+        input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+        result = input.clone() * (1 - ar_mask)
 
-                # tokens_output = [c for c in self.problem.seq2str(input[0])]
-                # tokens_input = ["n/a"] + tokens_output[:-1]
-                # for n_head in range(ram[0].size(1)):
-                # filename = os.path.join(
-                # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
-                # )
-                # attention_matrices = [m[0, n_head] for m in ram]
-                # save_attention_image(
-                # filename,
-                # tokens_input,
-                # tokens_output,
-                # attention_matrices,
-                # k_top=10,
-                ##min_total_attention=0.9,
-                # token_gap=12,
-                # layer_gap=50,
-                # )
-                # logger(f"wrote {filename}")
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        img = world.sample2img(result.to("cpu"), self.height, self.width)
+
+        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        logger(f"wrote {image_name}")
 
 
 ######################################################################