Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 18 Feb 2024 21:39:54 +0000 (22:39 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 18 Feb 2024 21:39:54 +0000 (22:39 +0100)
main.py
tasks.py

diff --git a/main.py b/main.py
index 9f82594..55f2c2f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -33,7 +33,7 @@ parser.add_argument(
     "--task",
     type=str,
     default="twotargets",
-    help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+    help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -86,6 +86,11 @@ parser.add_argument("--overwrite_results", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
+##############################
+# filetask
+
+parser.add_argument("--filetask_file", type=str, default=None)
+
 ##############################
 # rpl options
 
@@ -180,6 +185,12 @@ if args.result_dir is None:
 ######################################################################
 
 default_task_args = {
+    "file": {
+        "model": "37M",
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
     "addition": {
         "model": "352M",
         "batch_size": 25,
@@ -390,7 +401,20 @@ picoclvr_pruner_eval = (
 
 ######################################################################
 
-if args.task == "byheart":
+if args.task == "file":
+    assert (
+        args.filetask_file is not None
+    ), "You have to specify the task file with --filetask_file <filename>"
+    task = tasks.TaskFromFile(
+        args.filetask_file,
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        device=device,
+    )
+    args.max_percents_of_test_in_train = 0
+
+elif args.task == "byheart":
     task = tasks.SandBox(
         problem=problems.ProblemByHeart(),
         nb_train_samples=args.nb_train_samples,
index 08aa8ca..00b7a49 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -70,6 +70,141 @@ class Task:
         pass
 
 
+class TaskFromFile(Task):
+    def tensorize(self, pairs):
+        len_max = max([len(x[0]) for x in pairs])
+
+        input = torch.cat(
+            [
+                torch.tensor(
+                    [
+                        [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
+                        for s in pairs
+                    ]
+                )
+            ],
+            0,
+        ).to("cpu")
+
+        pred_mask = torch.cat(
+            [
+                torch.tensor(
+                    [
+                        [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
+                        for s in pairs
+                    ]
+                )
+            ],
+            0,
+        ).to("cpu")
+
+        return input, pred_mask
+
+    # trim all the tensors in the tuple z to remove as much token from
+    # left and right in the first tensor. If z is a tuple, all its
+    # elements are trimed according to the triming for the first
+    def trim(self, z, token="#"):
+        n = self.char2id[token]
+        if type(z) == tuple:
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    def __init__(
+        self,
+        filename,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.device = device
+
+        pairs = []
+        with open(filename, "r") as f:
+            for _ in range(nb_train_samples + nb_test_samples):
+                sequence = f.readline().strip()
+                pred_mask = f.readline().strip()
+                assert len(sequence) == len(pred_mask)
+                assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}"
+                pairs.append((sequence, pred_mask))
+
+        symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+        print("SANITY", symbols)
+        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
+        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+        print(self.char2id)
+
+        self.train_input, self.train_pred_masks = self.tensorize(
+            pairs[:nb_train_samples]
+        )
+        self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield self.trim(batch).to(self.device)
+
+    def vocabulary_size(self):
+        return len(self.char2id)
+
+    def tensor2str(self, t):
+        print(f"{type(t)=}")
+        return ["".join([self.id2char[x.item()] for x in s]) for s in t]
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.trim(self.test_input[:1000]).to(self.device)
+        result = correct.clone()
+        pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
+        ar_mask = (pred_mask > 0).long()
+        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
+
+        logger(f"----------------------------------------------------------")
+
+        for e in self.tensor2str(result[:10]):
+            logger(f"test_before {e}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        logger(f"----------------------------------------------------------")
+
+        for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
+            logger(f"test_after  {e}")
+            logger(f"correct     {c}")
+
+        logger(f"----------------------------------------------------------")
+
+        err_mask = (pred_mask == 2).long()
+        nb_total = err_mask.sum().item()
+        nb_correct = ((correct == result).long() * err_mask).sum().item()
+
+        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
 ####################
 
 import problems
@@ -1484,6 +1619,7 @@ class Grid(Task):
         self.device = device
         self.batch_size = batch_size
         self.grid_factory = grid.GridFactory(size=size)
+        self.fraction_play = fraction_play
 
         if logger is not None:
             logger(
@@ -1495,13 +1631,21 @@ class Grid(Task):
             fraction_play=fraction_play,
             progress_bar=lambda r: tqdm.tqdm(r),
         )
+
         self.test_descr = self.grid_factory.generate_samples(
             nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
         )
 
+        if fraction_play > 0:
+            self.play_descr = self.grid_factory.generate_samples(
+                nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
+            )
+        else:
+            self.play_descr = []
+
         # Build the tokenizer
         tokens = set()
-        for d in [self.train_descr, self.test_descr]:
+        for d in [self.train_descr, self.test_descr, self.play_descr]:
             for s in d:
                 for t in s.strip().split(" "):
                     tokens.add(t)
@@ -1515,10 +1659,14 @@ class Grid(Task):
         self.t_nul = self.token2id["#"]
         self.t_true = self.token2id["true"]
         self.t_false = self.token2id["false"]
+        self.t_pipe = self.token2id["|"]
 
         # Tokenize the train and test sets
         self.train_input = self.str2tensor(self.train_descr)
         self.test_input = self.str2tensor(self.test_descr)
+        self.play_input = (
+            None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
+        )
 
     def batches(self, split="train"):
         assert split in {"train", "test"}
@@ -1566,41 +1714,31 @@ class Grid(Task):
         logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
         logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
 
-        if n_epoch == 5 or n_epoch == 10 or n_epoch == 20:
-            if save_attention_image is None:
-                logger("no save_attention_image (is pycairo installed?)")
-            else:
-                for k in range(10):
-                    ns = k  # torch.randint(self.test_input.size(0), (1,)).item()
-                    input = self.test_input[ns : ns + 1].clone()
-                    with torch.autograd.no_grad():
-                        t = model.training
-                        model.eval()
-                        model.record_attention(True)
-                        model(BracketedSequence(input))
-                        model.train(t)
-                        ram = model.retrieve_attention()
-                        model.record_attention(False)
-
-                    tokens_output = [self.id2token[t.item()] for t in input[0]]
-                    tokens_input = ["n/a"] + tokens_output[:-1]
-                    for n_head in range(ram[0].size(1)):
-                        filename = os.path.join(
-                            result_dir,
-                            f"sandbox_attention_epoch_{n_epoch}_sample_{k}_head_{n_head}.pdf",
-                        )
-                        attention_matrices = [m[0, n_head] for m in ram]
-                        save_attention_image(
-                            filename,
-                            tokens_input,
-                            tokens_output,
-                            attention_matrices,
-                            k_top=10,
-                            # min_total_attention=0.9,
-                            token_gap=12,
-                            layer_gap=50,
-                        )
-                        logger(f"wrote {filename}")
+        if self.play_input is not None:
+            result = self.play_input.clone()
+            ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
+            result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
+
+            logger(f"----------------------------------------------------------")
+
+            for e in self.tensor2str(result[:10]):
+                logger(f"play_before {e}")
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
+            logger(f"----------------------------------------------------------")
+
+            for e in self.tensor2str(result[:10]):
+                logger(f"play_after  {e}")
+
+            logger(f"----------------------------------------------------------")
 
 
 ######################################################################