picoclvr: hold out a fraction of stack-task values for the test set, and give each task its own default result directory.
diff --git a/main.py b/main.py
index 314a961..2ed6b6b 100755
--- a/main.py
+++ b/main.py
@@ -37,7 +37,7 @@ parser.add_argument(
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
-parser.add_argument("--result_dir", type=str, default="results_default")
+parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
@@ -113,7 +113,9 @@ parser.add_argument("--stack_nb_steps", type=int, default=100)
 
 parser.add_argument("--stack_nb_stacks", type=int, default=1)
 
-parser.add_argument("--stack_nb_digits", type=int, default=1)
+parser.add_argument("--stack_nb_digits", type=int, default=3)
+
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
 
 ######################################################################
 
@@ -142,30 +144,35 @@ if args.seed >= 0:
 
 default_args = {
     "picoclvr": {
+        "result_dir": "results_picoclvr",
         "nb_epochs": 25,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "mnist": {
+        "result_dir": "results_mnist",
         "nb_epochs": 25,
         "batch_size": 10,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "maze": {
+        "result_dir": "results_maze",
         "nb_epochs": 25,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "snake": {
+        "result_dir": "results_snake",
         "nb_epochs": 5,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "stack": {
+        "result_dir": "results_stack",
         "nb_epochs": 5,
         "batch_size": 25,
         "nb_train_samples": 100000,
@@ -876,6 +883,7 @@ class TaskStack(Task):
         nb_steps,
         nb_stacks,
         nb_digits,
+        fraction_values_for_train=None,
         device=torch.device("cpu"),
     ):
         self.batch_size = batch_size
@@ -884,12 +892,31 @@ class TaskStack(Task):
         self.nb_digits = nb_digits
         self.device = device
 
+        if fraction_values_for_train is None:
+            values_for_train = None
+            values_for_test = None
+        else:
+            all_values = torch.randperm(10**nb_digits)  # avoid shadowing builtin all()
+            nb_for_train = int(all_values.size(0) * fraction_values_for_train)
+            values_for_train = all_values[:nb_for_train]
+            values_for_test = all_values[nb_for_train:]
+
         self.train_input, self.train_stack_counts = stack.generate_sequences(
-            nb_train_samples, nb_steps, nb_stacks, nb_digits, self.device
+            nb_train_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_train,
+            self.device,
         )
 
         self.test_input, self.test_stack_counts = stack.generate_sequences(
-            nb_test_samples, nb_steps, nb_stacks, nb_digits, self.device
+            nb_test_samples,
+            nb_steps,
+            nb_stacks,
+            nb_digits,
+            values_for_test,
+            self.device,
         )
 
         mask = self.test_input.clone()
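
The value split is entirely local to this constructor: a random permutation of all 10**nb_digits possible values is cut at the requested fraction, so train and test draw from disjoint value sets and test accuracy measures generalization to values never seen during training. A standalone demo of the same arithmetic (names mirror the diff; nothing here comes from the stack module):

    import torch

    nb_digits, fraction = 3, 0.75
    all_values = torch.randperm(10**nb_digits)
    nb_for_train = int(all_values.size(0) * fraction)
    values_for_train = all_values[:nb_for_train]
    values_for_test = all_values[nb_for_train:]

    # A permutation visits each value exactly once, so the two slices
    # cannot share an element.
    assert not set(values_for_train.tolist()) & set(values_for_test.tolist())

When fraction_values_for_train is None, both sets stay None and stack.generate_sequences presumably samples from all values, preserving the old behavior.
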
@@ -946,7 +973,9 @@ class TaskStack(Task):
             )
 
             #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            input = self.test_input[:10, :20]
+            seq_len = 50  # truncated below to whole 1 + nb_digits token operations
+            seq_len -= seq_len % (1 + self.nb_digits)
+            input = self.test_input[:10, :seq_len]
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
             ar_mask = (result != input).long()
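
The rounding keeps the truncated prefix aligned on whole operations: judging from this arithmetic, each push/pop in the encoding occupies one marker token followed by nb_digits digit tokens. A quick check of the numbers:

    # With the new default nb_digits = 3, an operation spans 4 tokens,
    # so the 50-token budget rounds down to 48.
    nb_digits = 3
    seq_len = 50
    seq_len -= seq_len % (1 + nb_digits)
    assert seq_len == 48
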
@@ -1038,6 +1067,7 @@ elif args.task == "stack":
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
         nb_digits=args.stack_nb_digits,
+        fraction_values_for_train=args.stack_fraction_values_for_train,
         device=device,
     )
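
Assuming the task is selected with a --task option, as the dispatch above suggests, an invocation like the following should hold out a quarter of all three-digit values for the test set (the 0.75 is purely illustrative):

    python main.py --task stack --stack_fraction_values_for_train 0.75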