- def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
- self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
- self.seq[:, len_prompt] = 10
+ def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1):
+ self.seq = torch.randint(
+ 10, (nb_sentences, len_prompt + separation + len_result)
+ )
+ self.seq[:, len_prompt : len_prompt + separation] = 10