Add a "stack" task, with per-task default sample counts.
diff --git a/main.py b/main.py
index 45bddb7..0323d02 100755
--- a/main.py
+++ b/main.py
@@ -32,7 +32,7 @@ parser = argparse.ArgumentParser(
 )
 
 parser.add_argument(
-    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+    "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake, stack"
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@ -106,6 +106,15 @@ parser.add_argument("--snake_nb_colors", type=int, default=5)
 
 parser.add_argument("--snake_length", type=int, default=200)
 
+##############################
+# Stack options
+
+parser.add_argument("--stack_nb_steps", type=int, default=25)
+
+parser.add_argument("--stack_nb_stacks", type=int, default=1)
+
+parser.add_argument("--stack_nb_values", type=int, default=10)
+
 ######################################################################
 
 args = parser.parse_args()
@@ -135,18 +144,32 @@ default_args = {
     "picoclvr": {
         "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "mnist": {
         "nb_epochs": 25,
         "batch_size": 10,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "maze": {
         "nb_epochs": 25,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
     },
     "snake": {
         "nb_epochs": 5,
         "batch_size": 25,
+        "nb_train_samples": 250000,
+        "nb_test_samples": 10000,
+    },
+    "stack": {
+        "nb_epochs": 25,
+        "batch_size": 25,
+        "nb_train_samples": 10000,
+        "nb_test_samples": 1000,
     },
 }
 
@@ -841,6 +864,84 @@ class TaskSnake(Task):
 ######################################################################
 
 
+import stack
+
+
+class TaskStack(Task):
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        nb_steps,
+        nb_stacks,
+        nb_values,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.nb_steps = nb_steps
+        self.nb_stacks = nb_stacks
+        self.nb_values = nb_values
+        self.device = device
+
+        self.train_input, self.train_stack_counts = stack.generate_sequences(
+            nb_train_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        self.test_input, self.test_stack_counts = stack.generate_sequences(
+            nb_test_samples, nb_steps, nb_stacks, nb_values, self.device
+        )
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(self, n_epoch, model):
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            def compute_nb_correct(input):
+                result = input.clone()
+                stack.remove_poped_values(result, self.nb_stacks)
+                ar_mask = (result != input).long()
+                result *= 1 - ar_mask
+
+                masked_inplace_autoregression(
+                    model, self.batch_size, result, ar_mask, device=self.device
+                )
+
+                nb_total = ar_mask.sum()
+
+                nb_correct = ((result == input).long() * ar_mask).sum()
+
+                return nb_total, nb_correct
+
+            test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
+
+            log_string(
+                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+            )
+
+            model.train(t)
+
+
+######################################################################
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -902,6 +1003,17 @@ elif args.task == "snake":
         device=device,
     )
 
+elif args.task == "stack":
+    task = TaskStack(
+        nb_train_samples=args.nb_train_samples,
+        nb_test_samples=args.nb_test_samples,
+        batch_size=args.batch_size,
+        nb_steps=args.stack_nb_steps,
+        nb_stacks=args.stack_nb_stacks,
+        nb_values=args.stack_nb_values,
+        device=device,
+    )
+
 else:
     raise ValueError(f"Unknown task {args.task}")
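
######################################################################

Note: TaskStack depends on two functions of the companion stack module
imported above, stack.generate_sequences() and stack.remove_poped_values()
(spelled this way at the call site), whose definitions are not part of
this diff. The sketch below is a minimal, illustrative stand-in inferred
from the call sites, not the repository's actual stack.py; in particular
the two-tokens-per-step encoding and the choice of 0 as the blank token
are assumptions made for this sketch only.

import torch

# Assumed token layout (an assumption of this sketch, not of stack.py):
#   operation tokens: 2*k   = "push onto stack k"
#                     2*k+1 = "pop from stack k"      (k < nb_stacks)
#   value tokens:     2*nb_stacks + v                 (v < nb_values)
# Each of the nb_steps steps is an operation token followed by a value
# token; for a pop, the value token is the value that came off the
# stack, which is what the model is later asked to predict.


def generate_sequences(nb, nb_steps, nb_stacks, nb_values, device):
    nb_ops = 2 * nb_stacks
    input = torch.empty(nb, 2 * nb_steps, dtype=torch.int64, device=device)
    stack_counts = torch.zeros(nb, nb_steps, dtype=torch.int64, device=device)

    for n in range(nb):
        stacks = [[] for _ in range(nb_stacks)]
        for t in range(nb_steps):
            k = torch.randint(nb_stacks, (1,)).item()
            # Force a push when stack k is empty, otherwise pick at random
            if len(stacks[k]) == 0 or torch.rand(1).item() < 0.5:
                v = torch.randint(nb_values, (1,)).item()
                stacks[k].append(v)
                input[n, 2 * t] = 2 * k
            else:
                v = stacks[k].pop()
                input[n, 2 * t] = 2 * k + 1
            input[n, 2 * t + 1] = nb_ops + v
            stack_counts[n, t] = len(stacks[k])

    return input, stack_counts


def remove_poped_values(seq, nb_stacks):
    # Operation tokens sit at even positions and are < 2*nb_stacks; pops
    # are the odd-valued ones (2*k+1). Blank, in place, the value token
    # that follows each pop, so that masked_inplace_autoregression()
    # has to regenerate it.
    op = seq[:, 0::2]
    is_pop = (op < 2 * nb_stacks) & (op % 2 == 1)
    seq[:, 1::2][is_pop] = 0


With the defaults added above, the new task can then be exercised with

    ./main.py --task stack --stack_nb_steps 25 --stack_nb_stacks 1 --stack_nb_values 10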