Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 16:38:22 +0000 (18:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 25 Aug 2023 16:38:22 +0000 (18:38 +0200)
grid.py
main.py
tasks.py

diff --git a/grid.py b/grid.py
index 08ddc23..70f7739 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -118,7 +118,7 @@ class GridFactory:
 
         return properties
 
-    def generate_example(self):
+    def generate_scene_and_questions(self):
         while True:
             while True:
                 scene = self.generate_scene()
@@ -142,25 +142,51 @@ class GridFactory:
                 if len(false) >= self.nb_questions:
                     break
 
+            # print(f"{a=}")
+
             if a < 10:
                 break
 
         true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
         false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
-        true = [(q, "yes") for q in true]
-        false = [(q, "no") for q in false]
+        true = ["<prop> " + q + " <true>" for q in true]
+        false = ["<prop> " + q + " <false>" for q in false]
 
         union = true + false
         questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
 
-        return scene, questions
+        result = " ".join(
+            ["<obj> " + x for x in self.grid_positions(scene)] + questions
+        )
+
+        return scene, result
+
+    def generate_samples(self, nb, progress_bar=None):
+        result = []
+
+        r = range(nb)
+        if progress_bar is not None:
+            r = progress_bar(r)
+
+        for _ in r:
+            result.append(self.generate_scene_and_questions()[1])
+
+        return result
 
 
 ######################################################################
 
 if __name__ == "__main__":
+    import time
+
     grid_factory = GridFactory()
-    scene, questions = grid_factory.generate_example()
+
+    start_time = time.perf_counter()
+    samples = grid_factory.generate_samples(10000)
+    end_time = time.perf_counter()
+    print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+
+    scene, questions = grid_factory.generate_scene_and_questions()
     grid_factory.print_scene(scene)
     print(questions)
 
diff --git a/main.py b/main.py
index ff831f4..00e19ac 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -33,7 +33,7 @@ parser.add_argument(
     "--task",
     type=str,
     default="twotargets",
-    help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl",
+    help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid",
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -262,6 +262,13 @@ default_task_args = {
         "nb_train_samples": 25000,
         "nb_test_samples": 1000,
     },
+    "grid": {
+        "model": "37M",
+        "nb_epochs": 25,
+        "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
 }
 
 if args.task in default_task_args:
@@ -505,6 +512,17 @@ elif args.task == "rpl":
         device=device,
     )
 
+elif args.task == "grid":
+    task = tasks.Grid(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        height=args.picoclvr_height,
+        width=args.picoclvr_width,
+        logger=log_string,
+        device=device,
+    )
+
 elif args.task == "world":
     task = tasks.World(
         nb_train_samples=args.nb_train_samples,
index 5019aed..c7348d5 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1419,6 +1419,131 @@ class Expr(Task):
         ##############################################################
 
 
+######################################################################
+
+import grid
+
+
+class Grid(Task):
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
+    # Make a list of strings from a tensor
+    def detensorize(self, x):
+        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+    # trim all the tensors in the tuple z to remove as much token from
+    # left and right in the first tensor. If z is a tuple, all its
+    # elements are trimed according to the triming for the first
+    def trim(self, z, token="<nul>"):
+        n = self.token2id[token]
+        if type(z) == tuple:
+            x = z[0]
+            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return tuple([t[:, a:b] for t in z])
+        else:
+            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+            return z[:, a:b]
+
+    ######################
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        height,
+        width,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.device = device
+        self.batch_size = batch_size
+        self.grid_factory = grid.GridFactory(height=height, width=width)
+
+        if logger is not None:
+            logger(
+                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+            )
+
+        self.train_descr = self.grid_factory.generate_samples(
+            nb_train_samples, lambda r: tqdm.tqdm(r)
+        )
+        self.test_descr = self.grid_factory.generate_samples(
+            nb_test_samples, lambda r: tqdm.tqdm(r)
+        )
+
+        # Build the tokenizer
+        tokens = {}
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        # make this set a sorted list to get the same tensors given
+        # the same descr
+        tokens = list(tokens)
+        tokens.sort()
+        tokens = ["<nul>"] + tokens
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+        self.t_nul = self.token2id["<nul>"]
+        self.t_true = self.token2id["<true>"]
+        self.t_false = self.token2id["<false>"]
+
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+        ):
+            yield self.trim(batch)
+
+    def vocabulary_size(self):
+        return len(self.token2id)
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        correct = self.test_input[:1000]
+        result = correct.clone()
+        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+        result *= 1 - ar_mask
+
+        for e in self.detensorize(result[:10]):
+            logger(f"test_before {e}")
+
+        masked_inplace_autoregression(
+            model,
+            self.batch_size,
+            result,
+            ar_mask,
+            deterministic_synthesis,
+            device=self.device,
+        )
+
+        for e in self.detensorize(result[:10]):
+            logger(f"test_after {e}")
+
+        nb_total = ar_mask.sum().item()
+        nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+        logger(f"test_performance {nb_total=} {nb_correct=}")
+        logger(f"main_test_accuracy {nb_correct / nb_total}")
+
+
 ######################################################################
 
 import world