Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:17:30 +0000 (22:17 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 18 Jul 2023 20:17:30 +0000 (22:17 +0200)
tasks.py

index 5ac78cb..706e1d9 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -87,29 +87,28 @@ class ProblemLevel0(Problem):
 
 
 class ProblemLevel1(Problem):
-    def __init__(self, nb_operators=100, len_prompt=5, len_result=8):
-        self.len_prompt = len_prompt
+    def __init__(self, nb_operators=100, len_source=5, len_result=8):
+        self.len_source = len_source
         self.len_result = len_result
         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
         self.operators = F.one_hot(
-            torch.rand(nb_operators, len_result, len_prompt).argmax(-1),
-            num_classes=len_prompt,
+            torch.rand(nb_operators, len_result, len_source).argmax(-1),
+            num_classes=len_source,
         )
 
+
+
     def generate_sequences(self, nb):
-        a = self.len_nb_operator
-        b = a + 1 + self.len_prompt
-        sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
         nb_operators = torch.randint(self.operators.size(0), (nb,))
-        sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a-1,-1,-1)) % 10
-        sequences[:, a] = 10
-        sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
-        sequences[:, b] = 11
-
-        o = self.operators[nb_operators]
-        p = sequences[:, a + 1 : b]
-        print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}")
-        sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1)
+        operators = self.operators[nb_operators]
+        nb_operators = (nb_operators[:, None] // 10 ** torch.arange(self.len_nb_operator-1,-1,-1)) % 10
+        marker1 = torch.full((nb,1),10)
+        source = torch.randint(10, (nb, self.len_source))
+        marker2 = torch.full((nb,1),11)
+        result = operators.bmm(source[:, :, None]).squeeze(-1)
+        print(f"{nb_operators.dtype=} {marker1.dtype=}")
+        sequences = torch.cat((nb_operators, marker1, source,marker2,result),1)
+        print(f"{sequences.size()=}")
         ar_mask = (sequences == 11).long()
         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
         return sequences, ar_mask