class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
- self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+ self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
def generate_sequences(self, nb):
class ProblemLevel1(Problem):
- def __init__(self, nb_operators=100, len_prompt=5, len_result=8):
- self.len_prompt = len_prompt
+ def __init__(self, nb_operators=100, len_source=5, len_result=8):
+ self.len_source = len_source
self.len_result = len_result
self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
self.operators = F.one_hot(
- torch.rand(nb_operators, len_result, len_prompt).argmax(-1),
- num_classes=len_prompt,
+ torch.rand(nb_operators, len_result, len_source).argmax(-1),
+ num_classes=len_source,
)
def generate_sequences(self, nb):
- a = self.len_nb_operator
- b = a + 1 + self.len_prompt
- sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
nb_operators = torch.randint(self.operators.size(0), (nb,))
- sequences[:, :a] = (nb_operators[:, None] / 10 ** torch.arange(a)) % 10
- sequences[:, a] = 10
- sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
- sequences[:, b] = 11
-
- o = self.operators[nb_operators]
- p = sequences[:, a + 1 : b]
- print(f"{o.size()=} {p.size()=} {sequences[:,b+1:].size()=}")
- sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1)
+ operators = self.operators[nb_operators]
+ nb_operators = (
+ nb_operators[:, None]
+ // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
+ ) % 10
+ marker1 = torch.full((nb, 1), 10)
+ source = torch.randint(10, (nb, self.len_source))
+ marker2 = torch.full((nb, 1), 11)
+ result = operators.bmm(source[:, :, None]).squeeze(-1)
+ print(f"{nb_operators.dtype=} {marker1.dtype=}")
+ sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
+ print(f"{sequences.size()=}")
ar_mask = (sequences == 11).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
- return "".join(self.id2char[x.item()] for x in seq)
+ return "".join("0123456789|>"[x.item()] for x in seq)
+
+
+class ProblemLevel2(Problem):
+ def __init__(self, len_source=5, len_result=8):
+ self.len_source = len_source
+ self.len_result = len_result
+
+ def generate_sequences(self, nb):
+ operators = F.one_hot(
+ torch.rand(nb, self.len_result, self.len_source).argmax(-1),
+ num_classes=self.len_source,
+ )
+ source1 = torch.randint(10, (nb, self.len_source))
+ marker1 = torch.full((nb, 1), 10)
+ result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
+ marker2 = torch.full((nb, 1), 11)
+ source2 = torch.randint(10, (nb, self.len_source))
+ marker3 = torch.full((nb, 1), 12)
+ result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
+
+ sequences = torch.cat(
+ (source1, marker1, result1, marker2, source2, marker3, result2), 1
+ )
+ ar_mask = (sequences == 12).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join("0123456789>|~"[x.item()] for x in seq)
####################