X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=446e1a1435b6c0be95b14fcbe3a655299ce5640a;hb=refs%2Fheads%2Fmaster;hp=d7dbc542aa14b0c77999b7b7a55ece9394e57c90;hpb=cd3329fc206bacfd90a8e2cbe364244359568733;p=picoclvr.git diff --git a/problems.py b/problems.py index d7dbc54..446e1a1 100755 --- a/problems.py +++ b/problems.py @@ -200,9 +200,11 @@ class ProblemTwoTargets(Problem): class ProblemByHeart(Problem): - def __init__(self, nb_sentences=100, len_prompt=8, len_result=8): - self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) - self.seq[:, len_prompt] = 10 + def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1): + self.seq = torch.randint( + 10, (nb_sentences, len_prompt + separation + len_result) + ) + self.seq[:, len_prompt : len_prompt + separation] = 10 def generate_sequences(self, nb): sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]