import math

import torch, torchvision

from torch.nn import functional as F

######################################################################

class Problem:
    def generate_sequences(self, nb):
        pass

    def seq2str(self, seq):
        return "[NOT IMPLEMENTED]"

    def compute_nb_correct(self, input, ar_mask, result):
        # Count, over the masked positions only, how many tokens of the
        # model's output match the ground truth.
        nb_total = ar_mask.sum().item()
        nb_correct = ((result == input).long() * ar_mask).sum().item()
        return nb_total, nb_correct
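
######################################################################

# Illustrative sketch, not part of the original file: how a Problem is
# meant to be consumed. generate_sequences() returns a batch of int64
# token sequences together with a mask that is 1 on the positions a model
# must generate autoregressively; compute_nb_correct() then scores a
# model's completion on exactly those positions.


def _demo_problem_protocol(problem, nb=4):
    sequences, ar_mask = problem.generate_sequences(nb)
    # A "model" that copies the ground truth verbatim attains the maximal
    # score; a real evaluation would fill the masked positions by sampling.
    result = sequences.clone()
    nb_total, nb_correct = problem.compute_nb_correct(sequences, ar_mask, result)
    print(f"{nb_correct}/{nb_total} correct")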

######################################################################


class ProblemDegradation(Problem):
    def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
        assert value_max // nb_state_tokens >= 2
        self.nb_state_tokens = nb_state_tokens
        self.nb_time_steps = nb_time_steps
        self.value_max = value_max
        self.hard = hard

    def generate_sequences(self, nb):
        # Initial state: all the mass value_max on a single random token.
        x = (
            torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
        ).long() * self.value_max
        seq = [x]

        for t in range(self.nb_time_steps - 1):
            # Pick one token holding at least 2 units, remove n of them,
            # and split them between its two neighbors (m / n - m).
            v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
            u = (v.max(dim=-1, keepdim=True).values == v).long()
            n = (
                (u * x)
                .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
                .sum(dim=-1, keepdim=True)
            )
            m = 1 + ((n - 1) * torch.rand(n.size())).long()
            x = (
                x
                + m * u.roll(shifts=-1, dims=-1)
                - n * u
                + (n - m) * u.roll(shifts=1, dims=-1)
            )
            seq.append(x)

        if self.hard:
            seq.reverse()

        seq = torch.cat(seq, dim=1)
        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)

    def compute_nb_correct(self, input, ar_mask, result):
        # A sequence is correct if its first state puts all the mass on one
        # token and if every transition is a valid degradation step.
        nb_total = result.size(0)
        nb_correct = 0
        e = result.new_zeros(self.nb_state_tokens)

        for seq in result:
            states = list(seq.split(self.nb_state_tokens))
            if self.hard:
                states.reverse()

            d = states[0]
            j = d.sort(descending=True).indices[0]
            e.zero_()
            e[j] = self.value_max
            if (d - e).abs().sum() == 0:
                nb_errors = 0
                for k in range(len(states) - 1):
                    d = states[k + 1] - states[k]
                    j = d.sort(descending=False).indices[0]
                    if (
                        d[j] == 0
                        or d[j] > self.value_max // 4
                        or d[(j + 1) % e.size(0)] <= 0
                        or d[(j + 1) % e.size(0)] >= -d[j]
                    ):
                        nb_errors += 1
                    else:
                        e.zero_()
                        e[j] = d[j]
                        e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
                        e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
                        if (d - e).abs().sum() > 0:
                            nb_errors += 1
                if nb_errors == 0:
                    nb_correct += 1

        return nb_total, nb_correct

    def seq2str(self, seq):
        return " | ".join(
            [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
        )

######################################################################


class ProblemMemory(Problem):
    def __init__(self, len_total=32):
        self.len_total = len_total
        self.max_len_pattern = 5
        self.nb_noise_tokens = 10
        self.start_pattern_token = 0
        self.end_pattern_token = 1
        self.start_result_token = 2
        self.end_result_token = 3
        self.token_string = "[]<>" + "".join(
            [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
        )

    def generate_sequences(self, nb):
        # Fill with noise tokens, then carve out a delimited pattern at a
        # random position and a delimited copy of it near the end.
        sequences = (
            torch.randint(self.nb_noise_tokens, (nb, self.len_total))
            + self.end_result_token
            + 1
        )
        len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
        pattern_positions = torch.randint(
            self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
        )
        k = self.len_total - (3 + self.max_len_pattern)
        for i in range(nb):
            l = len_patterns[i]
            j = pattern_positions[i]
            sequences[i, j] = self.start_pattern_token
            sequences[i, j + l + 2] = self.end_pattern_token
            sequences[i, k] = self.start_result_token
            sequences[i, k + l + 2] = self.end_result_token
            sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]

        j = torch.arange(self.len_total)[None, :]
        ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()

        return sequences, ar_mask

    def seq2str(self, seq):
        def decode(x):
            if x < len(self.token_string):
                return self.token_string[x]
            else:
                return "?"

        return "".join(decode(x.item()) for x in seq)

######################################################################


class ProblemTwoTargets(Problem):
    def __init__(self, len_total=10, len_targets=3):
        assert len_targets >= 3
        assert len_total >= 3 * len_targets - 1
        self.len_total = len_total
        self.len_targets = len_targets

    def generate_sequences(self, nb):
        # Plant two non-overlapping targets, the first delimited by "-"
        # tokens (10) and the second by "+" tokens (11), then append the
        # inside of both targets separated by "|" markers (12).
        k = torch.arange(self.len_total)[None, :]
        s = torch.randint(10, (nb, self.len_total))
        l = torch.rand(nb, self.len_total)
        l = l * (k <= self.len_total - self.len_targets).long()
        k1 = l.argmax(dim=1, keepdim=True)
        m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
        s = s * m + 10 * (1 - m)
        l = l * (
            1
            - (k + self.len_targets - 1 >= k1).long()
            * (k < k1 + self.len_targets).long()
        )
        k2 = l.argmax(dim=1, keepdim=True)
        m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
        s = s * m + 11 * (1 - m)
        a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
        a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
        sequences = torch.cat(
            (
                s,
                torch.full((nb, 1), 12),
                a1,
                torch.full((nb, 1), 12),
                a2,
                torch.full((nb, 1), 12),
            ),
            1,
        )
        ar_mask = (sequences == 12).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        return "".join("0123456789-+|"[x.item()] for x in seq)

######################################################################


class ProblemByHeart(Problem):
    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
        self.seq[:, len_prompt] = 10

    def generate_sequences(self, nb):
        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
        ar_mask = (sequences == 10).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        return "".join("0123456789|"[x.item()] for x in seq)

######################################################################


class ProblemLearnOperator(Problem):
    def __init__(self, nb_operators=100, len_source=6, len_result=9):
        self.len_source = len_source
        self.len_result = len_result
        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
        self.operators = F.one_hot(
            torch.rand(nb_operators, len_result, len_source).argmax(-1),
            num_classes=len_source,
        )

    def generate_sequences(self, nb):
        nb_operators = torch.randint(self.operators.size(0), (nb,))
        operators = self.operators[nb_operators]
        # Spell the operator id out digit by digit.
        nb_operators = (
            nb_operators[:, None]
            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
        ) % 10
        marker1 = torch.full((nb, 1), 10)
        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
        marker2 = torch.full((nb, 1), 11)
        result = operators.bmm(source[:, :, None]).squeeze(-1)
        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        return "".join("0123456789|>"[x.item()] for x in seq)

######################################################################


class ProblemGuessOperator(Problem):
    def __init__(self, len_source=5, len_result=8):
        self.len_source = len_source
        self.len_result = len_result

    def generate_sequences(self, nb):
        # A fresh random 0/1 operator per sequence, shown once through the
        # pair (source1, result1) and to be applied to source2.
        operators = F.one_hot(
            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
            num_classes=self.len_source,
        )
        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
        marker1 = torch.full((nb, 1), 10)
        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
        marker2 = torch.full((nb, 1), 11)
        source2 = torch.randint(10, (nb, self.len_source))
        marker3 = torch.full((nb, 1), 12)
        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)

        sequences = torch.cat(
            (source1, marker1, result1, marker2, source2, marker3, result2), 1
        )
        ar_mask = (sequences == 12).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        return "".join("0123456789>|~"[x.item()] for x in seq)

######################################################################


class ProblemAddition(Problem):
    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
        self.nb_digits = nb_digits
        self.zero_padded = zero_padded
        self.inverted_result = inverted_result
        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
        self.id2char = dict([(n, c) for c, n in self.char2id.items()])

    def tensorize(self, strings):
        # Pad every string with "$" to the length of the longest one and
        # map characters to token ids.
        len_max = max([len(x) for x in strings])
        return torch.tensor(
            [
                [self.char2id[c] for c in s + "$" * (len_max - len(s))]
                for s in strings
            ]
        )

    def generate_sequences(self, nb):
        sequences = []
        for _ in range(nb):
            a, b = torch.randint(10**self.nb_digits, (2,))
            c = a + b
            a, b, c = str(a.item()), str(b.item()), str(c.item())
            if self.zero_padded:
                a = "0" * (self.nb_digits - len(a)) + a
                b = "0" * (self.nb_digits - len(b)) + b
                c = "0" * (self.nb_digits + 1 - len(c)) + c
            if self.inverted_result:
                c = c[::-1]
            sequences.append(f"{a}+{b}={c}$")

        sequences = self.tensorize(sequences)
        ar_mask = (sequences == self.char2id["="]).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        return "".join(self.id2char[x.item()] for x in seq)

######################################################################


class ProblemMixing(Problem):
    def __init__(
        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
    ):
        self.height = height
        self.width = width
        self.nb_time_steps = nb_time_steps
        self.hard = hard
        self.random_start = random_start

    def start_random(self, nb):
        y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)

        if self.random_start:
            # Blank out one random row and one random column with the
            # out-of-range value height * width.
            i = (
                torch.arange(self.height)
                .reshape(1, -1, 1)
                .expand(nb, self.height, self.width)
            )
            j = (
                torch.arange(self.width)
                .reshape(1, 1, -1)
                .expand(nb, self.height, self.width)
            )

            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)

            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)

            y = y * m + self.height * self.width * (1 - m)

        y = y.reshape(nb, self.height, self.width)

        return y

    def start_error(self, x):
        if self.random_start:
            i = (
                torch.arange(self.height, device=x.device)
                .reshape(1, -1, 1)
                .expand_as(x)
            )
            j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)

            # Recover the blanked row / column as the ones holding the most
            # out-of-range values.
            ri = (
                (x == self.height * self.width)
                .long()
                .sum(dim=-1)
                .argmax(-1)
                .view(-1, 1, 1)
            )
            rj = (
                (x == self.height * self.width)
                .long()
                .sum(dim=-2)
                .argmax(-1)
                .view(-1, 1, 1)
            )

            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
        else:
            m = 1

        x = x.flatten(1)
        u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)

        d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)

        return d

    def moves(self, x):
        # Enumerate all single-step moves: each row rolled one step left or
        # right, and each column rolled one step up or down.
        y = (
            x[:, None, :, :]
            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
            .clone()
        )
        k = 0

        for i in range(self.height):
            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
            k += 1
            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
            k += 1

        for j in range(self.width):
            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
            k += 1
            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
            k += 1

        return y

    def generate_sequences(self, nb):
        x = self.start_random(nb)

        seq = [x.flatten(1)]

        for t in range(self.nb_time_steps - 1):
            y = self.moves(x)
            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
            seq.append(x.flatten(1))

        if self.hard:
            seq.reverse()

        seq = torch.cat(seq, dim=1)
        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)

    def compute_nb_correct(self, input, ar_mask, result):
        # A sequence is correct if its first state is a valid start and if
        # every successive state is one move away from its predecessor.
        a = [
            x.reshape(result.size(0), self.height, self.width)
            for x in result.split(self.height * self.width, dim=1)
        ]
        if self.hard:
            a.reverse()

        x = a[0]

        d = self.start_error(x)

        for t in range(self.nb_time_steps - 1):
            x0, x = a[t], a[t + 1]
            y = self.moves(x0)
            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values

        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()

        return nb_total, nb_correct

    def seq2str(self, seq):
        return " | ".join(
            [
                " ".join(
                    [
                        "-".join(
                            [
                                f"{x:02d}" if x < self.height * self.width else "**"
                                for x in s
                            ]
                        )
                        for s in r.split(self.width)
                    ]
                )
                for r in seq.split(self.height * self.width)
            ]
        )

######################################################################

if __name__ == "__main__":
    p = ProblemMixing(height=3, width=3, random_start=False)

    s, m = p.generate_sequences(10000)

    # Sanity check: freshly generated sequences should all be valid.
    print(p.compute_nb_correct(None, None, s))