X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;fp=problems.py;h=446e1a1435b6c0be95b14fcbe3a655299ce5640a;hb=798d9526e726b644979cf1124e714f705fdd5966;hp=d7dbc542aa14b0c77999b7b7a55ece9394e57c90;hpb=3528c66810984055a0e0f0cf7a4169c3340be0c8;p=picoclvr.git diff --git a/problems.py b/problems.py index d7dbc54..446e1a1 100755 --- a/problems.py +++ b/problems.py @@ -200,9 +200,11 @@ class ProblemTwoTargets(Problem): class ProblemByHeart(Problem): - def __init__(self, nb_sentences=100, len_prompt=8, len_result=8): - self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) - self.seq[:, len_prompt] = 10 + def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1): + self.seq = torch.randint( + 10, (nb_sentences, len_prompt + separation + len_result) + ) + self.seq[:, len_prompt : len_prompt + separation] = 10 def generate_sequences(self, nb): sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]