Update.
[picoclvr.git] / tasks.py
index e7c2f75..c5418b4 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -104,7 +104,8 @@ class ProblemLevel1(Problem):
             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
         ) % 10
         marker1 = torch.full((nb, 1), 10)
-        source = torch.randint(10, (nb, self.len_source))
+        # source = torch.randint(10, (nb, self.len_source))
+        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
         marker2 = torch.full((nb, 1), 11)
         result = operators.bmm(source[:, :, None]).squeeze(-1)
         print(f"{nb_operators.dtype=} {marker1.dtype=}")
@@ -128,7 +129,8 @@ class ProblemLevel2(Problem):
             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
             num_classes=self.len_source,
         )
-        source1 = torch.randint(10, (nb, self.len_source))
+        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+        # source1 = torch.randint(10, (nb, self.len_source))
         marker1 = torch.full((nb, 1), 10)
         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
         marker2 = torch.full((nb, 1), 11)