class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
- self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+ self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
def generate_sequences(self, nb):