Update.
[picoclvr.git] / problems.py
index d7dbc54..446e1a1 100755 (executable)
@@ -200,9 +200,11 @@ class ProblemTwoTargets(Problem):
 
 
 class ProblemByHeart(Problem):
-    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
-        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
-        self.seq[:, len_prompt] = 10
+    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1):
+        self.seq = torch.randint(
+            10, (nb_sentences, len_prompt + separation + len_result)
+        )
+        self.seq[:, len_prompt : len_prompt + separation] = 10
 
     def generate_sequences(self, nb):
         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]