class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
- self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+ self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
def generate_sequences(self, nb):
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
- source = torch.randint(10, (nb, self.len_source))
+ # source = torch.randint(10, (nb, self.len_source))
+ source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
print(f"{nb_operators.dtype=} {marker1.dtype=}")
torch.rand(nb, self.len_result, self.len_source).argmax(-1),
num_classes=self.len_source,
)
- source1 = torch.randint(10, (nb, self.len_source))
+ source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
+ # source1 = torch.randint(10, (nb, self.len_source))
marker1 = torch.full((nb, 1), 10)
result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
marker2 = torch.full((nb, 1), 11)