X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=e7c2f75897160bec75e1ec2bb91c762a4be29b04;hb=d6f73f1d5093fb098e822e14db382dd3a1c63a2a;hp=73f61bf102dd46387bde2ecd99f0d7a74c2d7250;hpb=a3211f96c7426a613b82a2de87d4dd70640e8f46;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 73f61bf..e7c2f75 100755 --- a/tasks.py +++ b/tasks.py @@ -76,7 +76,7 @@ class Problem: class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): - self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) self.seq[:, len_prompt] = 10 def generate_sequences(self, nb):