Update.
[culture.git] / tasks.py
index 49b83ec..622cd56 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2100,7 +2100,7 @@ import world
 
 class World(Task):
     def save_image(self, input, result_dir, filename, logger):
-        img = world.sample2img(self.train_input.to("cpu"), self.height, self.width)
+        img = world.sample2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
         logger(f"wrote {image_name}")
@@ -2226,7 +2226,9 @@ class World(Task):
             device=self.device,
         )
 
-        self.save_image(result, result_dir, f"world_result_{n_epoch:04d}.png", logger)
+        self.save_image(
+            result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger
+        )
 
         return main_test_accuracy
 
@@ -2258,30 +2260,32 @@ class World(Task):
             device=self.device,
         )
 
-        nb_correct = torch.empty(nb, device=self.device, dtype=torch.int64)
+        input = (
+            new_quizzes[:, None, :]
+            .expand(-1, nb_runs, -1)
+            .clone()
+            .reshape(-1, new_quizzes.size(-1))
+        )
+        result = input.clone()
 
-        for n in tqdm.tqdm(
-            range(new_quizzes.size(0)), dynamic_ncols=True, desc="checking quizzes"
-        ):
-            result = new_quizzes[n][None, :].expand(nb_runs, -1).clone()
-            ar_mask = (
-                (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
-                .long()[None, :]
-                .expand_as(result)
-            )
+        ar_mask = (
+            (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
+            .long()[None, :]
+            .expand_as(result)
+        )
 
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis=False,
-                progress_bar_desc=None,
-                device=self.device,
-            )
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis=False,
+            progress_bar_desc=None,
+            device=self.device,
+        )
 
-            nb_correct[n] = (
-                (new_quizzes[n][None, :] == result).long().min(dim=1).values.sum()
-            )
+        nb_correct = (
+            (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
+        )
 
         return new_quizzes, nb_correct