Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 8c4b7a1..2ed6b6b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -37,7 +37,7 @@ parser.add_argument(
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
-parser.add_argument("--result_dir", type=str, default="results_default")
+parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
@@ -144,30 +144,35 @@ if args.seed >= 0:
 
 default_args = {
     "picoclvr": {
+        "result_dir": "results_picoclvr",
         "nb_epochs": 25,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "mnist": {
+        "result_dir": "results_mnist",
         "nb_epochs": 25,
         "batch_size": 10,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "maze": {
+        "result_dir": "results_maze",
         "nb_epochs": 25,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "snake": {
+        "result_dir": "results_snake",
         "nb_epochs": 5,
         "batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "stack": {
+        "result_dir": "results_stack",
         "nb_epochs": 5,
         "batch_size": 25,
         "nb_train_samples": 100000,
@@ -968,8 +973,8 @@ class TaskStack(Task):
             )
 
             #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            l=50
-            l=l-l%(1+self.nb_digits)
+            l = 50
+            l = l - l % (1 + self.nb_digits)
             input = self.test_input[:10, :l]
             result = input.clone()
             stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)