Update.
author     François Fleuret <francois@fleuret.org>
           Tue, 18 Jul 2023 20:23:03 +0000 (22:23 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 18 Jul 2023 20:23:03 +0000 (22:23 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 213524e..e3fd9f0 100755
--- a/main.py
+++ b/main.py
@@ -266,7 +266,7 @@ picoclvr_pruner_eval = (
 
 if args.task == "sandbox":
     task = tasks.SandBox(
-        tasks.ProblemLevel1(),
+        tasks.ProblemLevel2(),
         # tasks.ProblemAddition(zero_padded=False, inverted_result=False),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
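Both problem classes in this repository expose the same minimal interface, which is why main.py can swap ProblemLevel1 for ProblemLevel2 with a one-line change. The sketch below is a hypothetical illustration of that contract, inferred from the two classes in this diff rather than from the SandBox code itself (the class name ProblemSketch and its toy payload are not part of the repository): generate_sequences(nb) returns a (nb, T) tensor of token ids together with a binary ar_mask of the same shape marking the positions to be generated autoregressively, and seq2str(seq) renders one sequence as a string.

    import torch

    class ProblemSketch:  # hypothetical, for illustration only
        def generate_sequences(self, nb):
            # (nb, 8) sequences of digit tokens; the model must predict the last 4
            sequences = torch.randint(10, (nb, 8))
            ar_mask = torch.zeros_like(sequences)
            ar_mask[:, 4:] = 1
            return sequences, ar_mask

        def seq2str(self, seq):
            return "".join("0123456789"[x.item()] for x in seq)
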
diff --git a/tasks.py b/tasks.py
index 706e1d9..73f61bf 100755
--- a/tasks.py
+++ b/tasks.py
@@ -96,18 +96,19 @@ class ProblemLevel1(Problem):
             num_classes=len_source,
         )
 
-
-
     def generate_sequences(self, nb):
         nb_operators = torch.randint(self.operators.size(0), (nb,))
         operators = self.operators[nb_operators]
-        nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10
-        marker1 = torch.full((nb,1),10)
+        nb_operators = (
+            nb_operators[:, None]
+            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
+        ) % 10
+        marker1 = torch.full((nb, 1), 10)
         source = torch.randint(10, (nb, self.len_source))
-        marker2 = torch.full((nb,1),11)
+        marker2 = torch.full((nb, 1), 11)
         result = operators.bmm(source[:, :, None]).squeeze(-1)
         print(f"{nb_operators.dtype=} {marker1.dtype=}")
-        sequences = torch.cat((nb_operators, marker1, source,marker2,result),1)
+        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
         print(f"{sequences.size()=}")
         ar_mask = (sequences == 11).long()
         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
@@ -117,6 +118,35 @@ class ProblemLevel1(Problem):
         return "".join("0123456789|>"[x.item()] for x in seq)
 
 
+class ProblemLevel2(Problem):
+    def __init__(self, len_source=5, len_result=8):
+        self.len_source = len_source
+        self.len_result = len_result
+
+    def generate_sequences(self, nb):
+        operators = F.one_hot(
+            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+            num_classes=self.len_source,
+        )
+        source1 = torch.randint(10, (nb, self.len_source))
+        marker1 = torch.full((nb, 1), 10)
+        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+        marker2 = torch.full((nb, 1), 11)
+        source2 = torch.randint(10, (nb, self.len_source))
+        marker3 = torch.full((nb, 1), 12)
+        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
+        sequences = torch.cat(
+            (source1, marker1, result1, marker2, source2, marker3, result2), 1
+        )
+        ar_mask = (sequences == 12).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        return "".join("0123456789>|~"[x.item()] for x in seq)
+
+
 ####################
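
Both generate_sequences implementations build ar_mask with the same two-line idiom: mark the occurrences of the closing marker token (11 in ProblemLevel1, 12 in ProblemLevel2), then turn that into a flag for every position strictly after the first marker, i.e. the part of the sequence the model has to produce. A minimal, self-contained sketch of that idiom on a toy tensor (the values are made up for illustration):

    import torch

    marker = 12
    sequences = torch.tensor([[3, 1, 4, marker, 5, 9, 2],
                              [2, 7, marker, 1, 8, 2, 8]])

    ar_mask = (sequences == marker).long()
    # cumsum counts markers seen so far (inclusive); subtracting ar_mask makes the
    # count exclusive of the marker itself, and clamp(max=1) yields a binary flag
    ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
    print(ar_mask)
    # tensor([[0, 0, 0, 0, 1, 1, 1],
    #         [0, 0, 0, 1, 1, 1, 1]])

In ProblemLevel2 the concatenated layout is source1, marker 10 ('>'), result1, marker 11 ('|'), source2, marker 12 ('~'), result2, so this mask selects exactly the result2 positions.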