Update.
[culture.git] / tasks.py
index 77493a8..f6d34a8 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -14,9 +14,6 @@ from torch.nn import functional as F
 
 from mygpt import BracketedSequence
 
-# from graph import save_attention_image
-save_attention_image = None
-
 ######################################################################
 
 
@@ -220,7 +217,7 @@ class World(Task):
         self.save_image(
             result[:96],
             result_dir,
-            f"world_result_{n_epoch:04d}_{model.id:02d}.png",
+            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
             logger,
         )
 
@@ -252,7 +249,7 @@ class World(Task):
             new_quizzes,
             ar_mask,
             deterministic_synthesis=False,
-            progress_bar_desc="new quizzes",
+            progress_bar_desc="creating quizzes",
             device=self.device,
         )
 
@@ -273,6 +270,29 @@ class World(Task):
                 device=self.device,
             )
 
-            nb_correct += (new_quizzes == result).long().min(dim=-1).values
+            l = self.height * self.width
+            direction = new_quizzes[:, l : l + 1]
+            direction = world.token_forward * (
+                direction == world.token_backward
+            ) + world.token_backward * (direction == world.token_forward)
+            inverted_quizzes = torch.cat(
+                [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
+            )
+
+            inverted_result = inverted_quizzes.clone()
+
+            masked_inplace_autoregression(
+                m,
+                self.batch_size,
+                inverted_result,
+                ar_mask,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving reversed quizzes",
+                device=self.device,
+            )
+
+            nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
+                inverted_quizzes == inverted_result
+            ).long().min(dim=-1).values
 
         return new_quizzes, nb_correct