Update.
[picoclvr.git] / problems.py
index aa3acf0..2e0ca36 100755 (executable)
@@ -47,14 +47,22 @@ class ProblemTwoTargets(Problem):
         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
         sequences = torch.cat(
-            (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+            (
+                s,
+                torch.full((nb, 1), 12),
+                a1,
+                torch.full((nb, 1), 12),
+                a2,
+                torch.full((nb, 1), 12),
+            ),
+            1,
         )
         ar_mask = (sequences == 12).long()
         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
         return sequences, ar_mask
 
     def seq2str(self, seq):
-        return "".join("0123456789+-|"[x.item()] for x in seq)
+        return "".join("0123456789-+|"[x.item()] for x in seq)
 
 
 ####################