Update.
[culture.git] / tasks.py
index c0ad5ff..b4e6f67 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -395,6 +395,133 @@ class SandBox(Task):
                 # logger(f"wrote {filename}")
 
 
                 # logger(f"wrote {filename}")
 
 
+######################################################################
+
+import world
+
+
+class World(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+        self.height = 6
+        self.width = 8
+
+        self.train_input = world.generate(
+            nb_train_samples, height=self.height, width=self.width
+        )
+        self.train_ar_mask = (
+            (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
+            .long()[None, :]
+            .expand_as(self.train_input)
+        )
+
+        self.test_input = world.generate(
+            nb_test_samples, height=self.height, width=self.width
+        )
+        self.test_ar_mask = (
+            (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
+            .long()[None, :]
+            .expand_as(self.test_input)
+        )
+
+        self.train_input, self.train_ar_mask = self.train_input.to(
+            device
+        ), self.train_ar_mask.to(device)
+        self.test_input, self.test_ar_mask = self.test_input.to(
+            device
+        ), self.test_ar_mask.to(device)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
+    ):
+        def compute_accuracy(input, ar_mask, logger=None):
+            input, ar_mask = input[:nmax], ar_mask[:nmax]
+            result = input.clone() * (1 - ar_mask)
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total, nb_correct = (
+                input.size(0),
+                (input == result).long().min(dim=1).values.sum(),
+            )
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(
+            self.train_input, self.train_ar_mask
+        )
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(
+            self.test_input, self.test_ar_mask, logger
+        )
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
+        ##############################
+
+        input, ar_mask = self.test_input[:64], self.test_ar_mask[:64]
+        result = input.clone() * (1 - ar_mask)
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        img = world.sample2img(result.to("cpu"), self.height, self.width)
+
+        image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        logger(f"wrote {image_name}")
+
+
 ######################################################################
 
 import picoclvr
 ######################################################################
 
 import picoclvr
@@ -754,15 +881,17 @@ class Maze(Task):
     def compute_error(
         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
     ):
     def compute_error(
         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
     ):
+        model_device = next(model.parameters()).device
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
             self.width * self.height,
             self.width * self.height,
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
             self.width * self.height,
             self.width * self.height,
-            device=self.device,
+            device=model_device,
             dtype=torch.int64,
         )
 
         for input in self.batches(split, nb_to_use):
             dtype=torch.int64,
         )
 
         for input in self.batches(split, nb_to_use):
+            input = input.to(model_device)
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
@@ -836,7 +965,7 @@ class Maze(Task):
                         eol = " " if j < count.size(1) - 1 else "\n"
                         f.write(f"{count[i,j]}{eol}")
 
                         eol = " " if j < count.size(1) - 1 else "\n"
                         f.write(f"{count[i,j]}{eol}")
 
-        input = self.test_input[:48]
+        input = self.test_input[:48].to(next(model.parameters()).device)
         result = input.clone()
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
         result = input.clone()
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
@@ -1098,6 +1227,34 @@ class Stack(Task):
             device=self.device,
         )
 
             device=self.device,
         )
 
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        for label, input in [
+            ("train", self.train_input[:32]),
+            ("test", self.test_input[:32]),
+        ]:
+            output = model(BracketedSequence(input)).x
+            output = output.log_softmax(dim=-1)
+            filename = os.path.join(
+                result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
+            )
+            with open(filename, "w") as f:
+                for n in range(input.size(0)):
+                    s = stack.seq_to_str(
+                        input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
+                    )
+                    for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
+                        u = (
+                            " " * (10 - len(w))
+                            + w
+                            + " "
+                            + str(output[n][t][k].exp().item())
+                            + "\n"
+                        )
+                        f.write(u)
+                    f.write("\n")
+            logger(f"wrote {filename}")
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
         for n in range(result.size(0)):
             logger(
                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
         for n in range(result.size(0)):
             logger(
                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"