Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:35:59 +0000 (22:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:35:59 +0000 (22:35 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index e3fd9f0..0d4930d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -82,6 +82,17 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # picoclvr options
 
+parser.add_argument("--sandbox_level", type=int, default=0)
+
+parser.add_argument("--sandbox_levels_nb_items", type=int, default=25)
+
+parser.add_argument("--sandbox_levels_len_source", type=int, default=5)
+
+parser.add_argument("--sandbox_levels_len_result", type=int, default=8)
+
+##############################
+# picoclvr options
+
 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
 
 parser.add_argument("--picoclvr_height", type=int, default=12)
@@ -265,8 +276,28 @@ picoclvr_pruner_eval = (
 ######################################################################
 
 if args.task == "sandbox":
+    if args.sandbox_level == 0:
+        problem = tasks.ProblemLevel0(
+            nb_sentences=args.sandbox_levels_nb_items,
+            len_prompt=args.sandbox_levels_len_source,
+            len_result=args.sandbox_levels_len_result,
+        )
+    elif args.sandbox_level == 1:
+        problem = tasks.ProblemLevel1(
+            nb_operators=args.sandbox_levels_nb_items,
+            len_source=args.sandbox_levels_len_source,
+            len_result=args.sandbox_levels_len_result,
+        )
+    elif args.sandbox_level == 2:
+        problem = tasks.ProblemLevel2(
+            len_source=args.sandbox_levels_len_source,
+            len_result=args.sandbox_levels_len_result,
+        )
+    else:
+        raise ValueError(f"Unknown sandbox level {args.sandbox_level}")
+
     task = tasks.SandBox(
-        tasks.ProblemLevel2(),
+        problem,
         # tasks.ProblemAddition(zero_padded=False, inverted_result=False),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
index 73f61bf..e7c2f75 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -76,7 +76,7 @@ class Problem:
 
 class ProblemLevel0(Problem):
     def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
-        self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
         self.seq[:, len_prompt] = 10
 
     def generate_sequences(self, nb):