Update.
[culture.git] / tasks.py
index 324376d..622cd56 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -63,7 +63,7 @@ def masked_inplace_autoregression(
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         pass
 
     def vocabulary_size(self):
@@ -489,7 +489,7 @@ class PicoCLVR(Task):
         self.train_input = self.tensorize(self.train_descr)
         self.test_input = self.tensorize(self.test_descr)
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -754,15 +754,17 @@ class Maze(Task):
     def compute_error(
         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
     ):
+        model_device = next(model.parameters()).device
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
             self.width * self.height,
             self.width * self.height,
-            device=self.device,
+            device=model_device,
             dtype=torch.int64,
         )
 
         for input in self.batches(split, nb_to_use):
+            input = input.to(model_device)
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
@@ -836,7 +838,7 @@ class Maze(Task):
                         eol = " " if j < count.size(1) - 1 else "\n"
                         f.write(f"{count[i,j]}{eol}")
 
-        input = self.test_input[:48]
+        input = self.test_input[:48].to(next(model.parameters()).device)
         result = input.clone()
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
@@ -1098,6 +1100,34 @@ class Stack(Task):
             device=self.device,
         )
 
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        for label, input in [
+            ("train", self.train_input[:32]),
+            ("test", self.test_input[:32]),
+        ]:
+            output = model(BracketedSequence(input)).x
+            output = output.log_softmax(dim=-1)
+            filename = os.path.join(
+                result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
+            )
+            with open(filename, "w") as f:
+                for n in range(input.size(0)):
+                    s = stack.seq_to_str(
+                        input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
+                    )
+                    for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
+                        u = (
+                            " " * (10 - len(w))
+                            + w
+                            + " "
+                            + str(output[n][t][k].exp().item())
+                            + "\n"
+                        )
+                        f.write(u)
+                    f.write("\n")
+            logger(f"wrote {filename}")
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
         for n in range(result.size(0)):
             logger(
                 f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
@@ -1685,7 +1715,7 @@ class Grid(Task):
         self.t_nul = self.token2id["#"]
         self.t_true = self.token2id["true"]
         self.t_false = self.token2id["false"]
-        self.t_pipe = self.token2id["|"]
+        self.t_pipe = self.token2id["|"]
 
         # Tokenize the train and test sets
         self.train_input = self.str2tensor(self.train_descr)
@@ -1694,7 +1724,7 @@ class Grid(Task):
             None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
         )
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -1823,7 +1853,7 @@ class QMLP(Task):
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         for batch in tqdm.tqdm(
@@ -1944,7 +1974,7 @@ class Greed(Task):
                 progress_bar_desc=None,
             )
             warnings.warn("keeping thinking snapshots", RuntimeWarning)
-            snapshots.append(result[:10].detach().clone())
+            snapshots.append(result[:100].detach().clone())
 
         # Generate iteration after iteration
 
@@ -1986,11 +2016,11 @@ class Greed(Task):
             # Set the lookahead_reward to UNKNOWN for the next iterations
             result[
                 :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN)
+            ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
 
         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
-            for n in range(10):
+            for n in range(snapshots[0].size(0)):
                 for s in snapshots:
                     lr, s, a, r = self.world.seq2episodes(
                         s[n : n + 1],
@@ -2063,3 +2093,199 @@ class Greed(Task):
 
 
 ######################################################################
+######################################################################
+
+import world
+
+
+class World(Task):
+    def save_image(self, input, result_dir, filename, logger):
+        img = world.sample2img(input.to("cpu"), self.height, self.width)
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+        logger(f"wrote {image_name}")
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        result_dir=None,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+        self.height = 6
+        self.width = 8
+
+        self.train_input = world.generate(
+            nb_train_samples, height=self.height, width=self.width
+        )
+        self.train_ar_mask = (
+            (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
+            .long()[None, :]
+            .expand_as(self.train_input)
+        )
+
+        self.test_input = world.generate(
+            nb_test_samples, height=self.height, width=self.width
+        )
+        self.test_ar_mask = (
+            (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
+            .long()[None, :]
+            .expand_as(self.test_input)
+        )
+
+        self.train_input, self.train_ar_mask = self.train_input.to(
+            device
+        ), self.train_ar_mask.to(device)
+        self.test_input, self.test_ar_mask = self.test_input.to(
+            device
+        ), self.test_ar_mask.to(device)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+        if result_dir is not None:
+            self.save_image(
+                self.train_input[:96], result_dir, f"world_train.png", logger
+            )
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
+    ):
+        def compute_accuracy(input, ar_mask, logger=None):
+            input, ar_mask = input[:nmax], ar_mask[:nmax]
+            result = input.clone() * (1 - ar_mask)
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total, nb_correct = (
+                input.size(0),
+                (input == result).long().min(dim=1).values.sum(),
+            )
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(
+            self.train_input, self.train_ar_mask
+        )
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(
+            self.test_input, self.test_ar_mask, logger
+        )
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        main_test_accuracy = test_nb_correct / test_nb_total
+        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
+
+        ##############################
+
+        input, ar_mask = self.test_input[:96], self.test_ar_mask[:96]
+        result = input.clone() * (1 - ar_mask)
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        self.save_image(
+            result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger
+        )
+
+        return main_test_accuracy
+
+    def store_new_quizzes(self, new_quizzes, for_train=True):
+        input = self.train_input if for_train else self.test_input
+
+        nb_current = input.size(0)
+        nb_new = new_quizzes.size(0)
+        if nb_new >= nb_current:
+            input[...] = new_quizzes[:nb_current]
+        else:
+            nb_kept = nb_current - nb_new
+            input[:nb_kept] = input[-nb_kept:].clone()
+            input[nb_kept:] = new_quizzes
+
+    def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs):
+        new_quizzes = torch.empty(
+            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
+        )
+        ar_mask = torch.full(new_quizzes.size(), 1, device=self.device)
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            new_quizzes,
+            ar_mask,
+            deterministic_synthesis=False,
+            progress_bar_desc="new quizzes",
+            device=self.device,
+        )
+
+        input = (
+            new_quizzes[:, None, :]
+            .expand(-1, nb_runs, -1)
+            .clone()
+            .reshape(-1, new_quizzes.size(-1))
+        )
+        result = input.clone()
+
+        ar_mask = (
+            (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
+            .long()[None, :]
+            .expand_as(result)
+        )
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis=False,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        nb_correct = (
+            (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
+        )
+
+        return new_quizzes, nb_correct