import math

import torch, torchvision
from torch.nn import functional as F
10 ######################################################################
14 def generate_sequences(self, nb):
17 def seq2str(self, seq):
18 return "[NOT IMPLEMENTED]"
class ProblemTwoTargets(Problem):
    """Synthetic task over digit tokens 0-9 with markers 11/12/13
    (rendered '|', '>', '_' by seq2str).

    NOTE(review): generate_sequences below is missing several lines
    (see comments); the code as-is does not parse. Do not ship without
    restoring the lost statements from version control.
    """

    def __init__(self, len_total=10, len_target=2):
        # Guarantees the total length can hold both marked targets.
        assert len_total >= 3 * (2 + len_target) - 1
        self.len_total = len_total
        self.len_target = len_target

    def generate_sequences(self, nb):
        # Position index, shape (1, len_total), broadcast against per-sample bounds.
        k = torch.arange(self.len_total)[None, :]
        # Two random segment lengths per sample, each >= 1; shape (2, nb, 1).
        l = torch.randint(self.len_total, (2, nb))[:, :, None] + 1
        # Two random digit values per sample; shape (2, nb, 1).
        i = torch.randint(10, (2, nb))[:, :, None]
        # End boundary of the third segment.
        # NOTE(review): the definitions of `a` and `b` (presumably
        # a = l[0] and b = l[0] + 1 + l[1]) are missing here.
        c = l[0] + 1 + l[1] + 1 + l[0]
        # NOTE(review): the following two lines look like fragments of a
        # larger `sequences = ( ... )` parenthesized sum whose opening,
        # remaining terms, and closing paren were lost.
            + (k > a) * (k < b) * i[1]
            + (k > b) * (k < c) * i[1]
        # Autoregressive mask: 1 strictly after the first token 11 ('>').
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        # Token ids 0-12 map onto this alphabet.
        return "".join("0123456789|>_"[x.item()] for x in seq)
class ProblemLenId(Problem):
    """Synthetic length/identity task over digit tokens 0-9 with markers
    11/12/13 (rendered '|', '>', '_' by seq2str).

    NOTE(review): generate_sequences below is missing several lines
    (see comments); the code as-is does not parse. Do not ship without
    restoring the lost statements from version control.
    """

    def __init__(self, len_max=10):
        self.len_max = len_max

    def generate_sequences(self, nb):
        # Position index long enough for three segments plus markers.
        k = torch.arange(self.len_max * 3 + 3)[None, :]
        # Two random segment lengths per sample, each >= 1; shape (2, nb, 1).
        l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
        # Two random digit values per sample; shape (2, nb, 1).
        i = torch.randint(10, (2, nb))[:, :, None]
        # End boundary of the third segment.
        # NOTE(review): the definitions of `a` and `b` (presumably
        # a = l[0] and b = l[0] + 1 + l[1]) are missing here.
        c = l[0] + 1 + l[1] + 1 + l[0]
        # NOTE(review): the following two lines look like fragments of a
        # larger `sequences = ( ... )` parenthesized sum whose opening,
        # remaining terms, and closing paren were lost.
            + (k > a) * (k < b) * i[1]
            + (k > b) * (k < c) * i[1]
        # Autoregressive mask: 1 strictly after the first token 11 ('>').
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        # Token ids 0-12 map onto this alphabet.
        return "".join("0123456789|>_"[x.item()] for x in seq)
class ProblemLevel0(Problem):
    """Pure memorization task.

    A fixed bank of random digit sentences is generated once at
    construction; each sample is one sentence from the bank. Token 10
    ('|' in seq2str) separates the prompt from the result, and only the
    result part is to be predicted autoregressively.
    """

    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
        # Pre-generate the whole sentence bank: digits 0-9 everywhere,
        # with the separator token 10 at the fixed prompt/result boundary.
        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
        self.seq[:, len_prompt] = 10

    def generate_sequences(self, nb):
        """Sample nb sentences (with replacement) from the bank.

        Returns (sequences, ar_mask) where ar_mask is 1 strictly after
        the first separator token 10.
        """
        picks = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[picks]
        at_marker = (sequences == 10).long()
        ar_mask = (at_marker.cumsum(1) - at_marker).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render token ids 0-10 as digits plus the '|' separator."""
        alphabet = "0123456789|"
        return "".join(alphabet[x.item()] for x in seq)
class ProblemLevel1(Problem):
    """Operator-application task.

    A fixed pool of `nb_operators` random 0/1 selection matrices is drawn
    at construction. Each sample spells an operator index in decimal
    digits, then marker 10 ('|'), a source digit string, marker 11 ('>'),
    and the result of applying that operator to the source. Only the
    result (after '>') is predicted autoregressively.

    NOTE(review): this block arrived with the closing paren of the
    F.one_hot(...) call and the decimal-digit expansion wrapper in
    generate_sequences missing; both restored here.
    """

    def __init__(self, nb_operators=100, len_source=5, len_result=8):
        self.len_source = len_source
        self.len_result = len_result
        # Number of decimal digits needed to spell any operator index.
        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
        # One 0/1 matrix of shape (len_result, len_source) per operator:
        # each result position copies exactly one source position.
        self.operators = F.one_hot(
            torch.rand(nb_operators, len_result, len_source).argmax(-1),
            num_classes=len_source,
        )

    def generate_sequences(self, nb):
        """Return (sequences, ar_mask) for nb samples."""
        nb_operators = torch.randint(self.operators.size(0), (nb,))
        operators = self.operators[nb_operators]
        # Spell each operator index as decimal digits, most significant first.
        nb_operators = (
            nb_operators[:, None]
            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
        ) % 10
        marker1 = torch.full((nb, 1), 10)
        # Source digits are distinct: prefix of a random permutation of 0-9
        # (as opposed to independent torch.randint draws).
        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
        marker2 = torch.full((nb, 1), 11)
        # Apply the selection matrix: result[j] = source[argmax row j].
        result = operators.bmm(source[:, :, None]).squeeze(-1)
        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
        # Predict only what follows the first token 11 ('>').
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render token ids 0-11 as digits plus '|' and '>'."""
        return "".join("0123456789|>"[x.item()] for x in seq)
class ProblemLevel2(Problem):
    """Operator-induction task.

    Each sample draws a fresh random 0/1 selection operator and shows one
    worked example (source1, marker 10, result1), then marker 11, a second
    source, marker 12, and the second result. The model must infer the
    operator from the example and predict result2 (everything after
    token 12). seq2str renders markers 10/11/12 as '>', '|', '~'.

    NOTE(review): this block arrived with two closing-paren lines missing
    (after the F.one_hot keyword argument and after the torch.cat tuple);
    both restored here.
    """

    def __init__(self, len_source=5, len_result=8):
        self.len_source = len_source
        self.len_result = len_result

    def generate_sequences(self, nb):
        """Return (sequences, ar_mask) for nb samples."""
        # One 0/1 matrix of shape (len_result, len_source) per sample:
        # each result position copies exactly one source position.
        operators = F.one_hot(
            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
            num_classes=self.len_source,
        )
        # source1 uses distinct digits (prefix of a random permutation of
        # 0-9); source2 uses independent random digits.
        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
        marker1 = torch.full((nb, 1), 10)
        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
        marker2 = torch.full((nb, 1), 11)
        source2 = torch.randint(10, (nb, self.len_source))
        marker3 = torch.full((nb, 1), 12)
        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
        sequences = torch.cat(
            (source1, marker1, result1, marker2, source2, marker3, result2), 1
        )
        # Predict only what follows the first token 12 ('~').
        ar_mask = (sequences == 12).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render token ids 0-12 as digits plus '>', '|', '~'."""
        return "".join("0123456789>|~"[x.item()] for x in seq)
class ProblemAddition(Problem):
    """Decimal addition task over the alphabet "0123456789+=$".

    Each sample is the string "a+b=c$" tokenized via char2id, with a and b
    drawn uniformly in [0, 10**nb_digits). With zero_padded, a and b are
    left-padded to nb_digits and c to nb_digits + 1; with inverted_result,
    the digits of c are reversed (easier for left-to-right generation).
    Only the part after '=' is predicted autoregressively.

    NOTE(review): this block arrived with several structural lines missing
    (tensorize's tensor construction, the sample loop, `c = a + b`, the
    zero_padded guard, and the result inversion); reconstructed here from
    the surviving lines.
    """

    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
        self.nb_digits = nb_digits
        self.zero_padded = zero_padded
        self.inverted_result = inverted_result
        # Bijective char <-> token-id maps over "0123456789+=$".
        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

    def tensorize(self, strings):
        """Encode strings as a (len(strings), len_max) long tensor,
        right-padding shorter strings with the '$' terminator."""
        len_max = max([len(x) for x in strings])
        return torch.cat(
            [
                torch.tensor(
                    [[self.char2id[c] for c in s + "$" * (len_max - len(s))]]
                )
                for s in strings
            ],
            0,
        )

    def generate_sequences(self, nb):
        """Return (sequences, ar_mask) for nb addition samples."""
        sequences = []
        for _ in range(nb):
            a, b = torch.randint(10**self.nb_digits, (2,))
            c = a + b
            a, b, c = str(a.item()), str(b.item()), str(c.item())
            if self.zero_padded:
                a = "0" * (self.nb_digits - len(a)) + a
                b = "0" * (self.nb_digits - len(b)) + b
                # The sum of two nb_digits numbers needs one extra digit.
                c = "0" * (self.nb_digits + 1 - len(c)) + c
            if self.inverted_result:
                c = c[::-1]
            sequences.append(f"{a}+{b}={c}$")
        sequences = self.tensorize(sequences)
        # Predict only what follows the first '=' token.
        ar_mask = (sequences == self.char2id["="]).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Decode a token tensor back into its character string."""
        return "".join(self.id2char[x.item()] for x in seq)
215 # class ProblemUnion(Problem):
216 # problems = [ProblemByheart()]
217 # nb_common_codes = 100
219 # def generate_sequences(nb_samples):
220 # problem_indexes = torch.randint(len(problems), (nb_samples,))
221 # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
222 # print(f"{nb_samples_per_problem}")
224 # for nb, p in zip(nb_samples_per_problem, problems):
225 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
228 # for strain, stest in zip(train_seq, test_seq):
229 # s = torch.cat((strain, stest), 0)