Oups
[picoclvr.git] / problems.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 ######################################################################
11
12
13 class Problem:
14     def generate_sequences(self, nb):
15         pass
16
17     def seq2str(self, seq):
18         return "[NOT IMPLEMENTED]"
19
20     def compute_nb_correct(self, input, ar_mask, result):
21         nb_total = ar_mask.sum().item()
22         nb_correct = ((result == input).long() * ar_mask).sum().item()
23         return nb_total, nb_correct
24
25
26 ####################
27
28
29 class ProblemDegradation(Problem):
30     def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
31         assert value_max // nb_state_tokens >= 2
32         self.nb_state_tokens = nb_state_tokens
33         self.nb_time_steps = nb_time_steps
34         self.value_max = value_max
35         self.hard = hard
36
37     def generate_sequences(self, nb):
38         x = (
39             torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
40         ).long() * self.value_max
41         seq = [x]
42
43         for t in range(self.nb_time_steps - 1):
44             v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
45             u = (v.max(dim=-1, keepdim=True).values == v).long()
46             n = (
47                 (u * x)
48                 .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
49                 .sum(dim=-1, keepdim=True)
50             )
51             m = 1 + ((n - 1) * torch.rand(n.size())).long()
52             x = (
53                 x
54                 + m * u.roll(shifts=-1, dims=-1)
55                 - n * u
56                 + (n - m) * u.roll(shifts=1, dims=-1)
57             )
58             seq.append(x)
59
60         if self.hard:
61             seq.reverse()
62
63         seq = torch.cat(seq, dim=1)
64         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
65
66     def compute_nb_correct(self, input, ar_mask, result):
67         nb_total = result.size(0)
68         nb_correct = 0
69         e = result.new_zeros(self.nb_state_tokens)
70
71         for seq in result:
72             states = list(seq.split(self.nb_state_tokens))
73             if self.hard:
74                 states.reverse()
75
76             d = states[0]
77             j = d.sort(descending=True).indices[0]
78             e.zero_()
79             e[j] = self.value_max
80             if (d - e).abs().sum() == 0:
81                 nb_errors = 0
82                 for k in range(len(states) - 1):
83                     d = states[k + 1] - states[k]
84                     j = d.sort(descending=False).indices[0]
85                     if (
86                         d[j] == 0
87                         or d[j] > self.value_max // 4
88                         or d[(j + 1) % e.size(0)] <= 0
89                         or d[(j + 1) % e.size(0)] >= -d[j]
90                     ):
91                         nb_errors += 1
92                     else:
93                         e.zero_()
94                         e[j] = d[j]
95                         e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
96                         e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
97                         if (d - e).abs().sum() > 0:
98                             nb_errors += 1
99                 if nb_errors == 0:
100                     nb_correct += 1
101
102         return nb_total, nb_correct
103
104     def seq2str(self, seq):
105         return " | ".join(
106             [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
107         )
108
109
110 ####################
111
112
113 class ProblemMemory(Problem):
114     def __init__(self, len_total=25):
115         self.len_total = len_total
116         self.max_len_pattern = 5
117         self.nb_noise_tokens = 10
118         self.start_pattern_token = 0
119         self.end_pattern_token = 1
120         self.start_result_token = 2
121         self.end_result_token = 3
122         self.token_string = "[]<>" + "".join(
123             [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
124         )
125
126     def generate_sequences(self, nb):
127         sequences = (
128             torch.randint(self.nb_noise_tokens, (nb, self.len_total))
129             + self.end_result_token
130             + 1
131         )
132         len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
133         pattern_positions = torch.randint(
134             self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
135         )
136         k = self.len_total - (3 + self.max_len_pattern)
137         for i in range(nb):
138             l = len_patterns[i]
139             j = pattern_positions[i]
140             sequences[i, j] = self.start_pattern_token
141             sequences[i, j + l + 2] = self.end_pattern_token
142             sequences[i, k] = self.start_result_token
143             sequences[i, k + l + 2] = self.end_result_token
144             sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
145
146         j = torch.arange(self.len_total)[None, :]
147         ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
148
149         return sequences, ar_mask
150
151     def seq2str(self, seq):
152         return "".join(self.token_string[x.item()] for x in seq)
153
154
155 class ProblemTwoTargets(Problem):
156     def __init__(self, len_total=10, len_targets=3):
157         assert len_targets >= 3
158         assert len_total >= 3 * len_targets - 1
159         self.len_total = len_total
160         self.len_targets = len_targets
161
162     def generate_sequences(self, nb):
163         k = torch.arange(self.len_total)[None, :]
164         s = torch.randint(10, (nb, self.len_total))
165         l = torch.rand(nb, self.len_total)
166         l = l * (k <= self.len_total - self.len_targets).long()
167         k1 = l.argmax(dim=1, keepdim=True)
168         m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
169         s = s * m + 10 * (1 - m)
170         l = l * (
171             1
172             - (k + self.len_targets - 1 >= k1).long()
173             * (k < k1 + self.len_targets).long()
174         )
175         k2 = l.argmax(dim=1, keepdim=True)
176         m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
177         s = s * m + 11 * (1 - m)
178         a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
179         a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
180         sequences = torch.cat(
181             (
182                 s,
183                 torch.full((nb, 1), 12),
184                 a1,
185                 torch.full((nb, 1), 12),
186                 a2,
187                 torch.full((nb, 1), 12),
188             ),
189             1,
190         )
191         ar_mask = (sequences == 12).long()
192         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
193         return sequences, ar_mask
194
195     def seq2str(self, seq):
196         return "".join("0123456789-+|"[x.item()] for x in seq)
197
198
199 ####################
200
201
202 class ProblemByHeart(Problem):
203     def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
204         self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
205         self.seq[:, len_prompt] = 10
206
207     def generate_sequences(self, nb):
208         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
209         ar_mask = (sequences == 10).long()
210         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
211         return sequences, ar_mask
212
213     def seq2str(self, seq):
214         return "".join("0123456789|"[x.item()] for x in seq)
215
216
217 ####################
218
219
220 class ProblemLearnOperator(Problem):
221     def __init__(self, nb_operators=100, len_source=6, len_result=9):
222         self.len_source = len_source
223         self.len_result = len_result
224         self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
225         self.operators = F.one_hot(
226             torch.rand(nb_operators, len_result, len_source).argmax(-1),
227             num_classes=len_source,
228         )
229
230     def generate_sequences(self, nb):
231         nb_operators = torch.randint(self.operators.size(0), (nb,))
232         operators = self.operators[nb_operators]
233         nb_operators = (
234             nb_operators[:, None]
235             // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
236         ) % 10
237         marker1 = torch.full((nb, 1), 10)
238         source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
239         marker2 = torch.full((nb, 1), 11)
240         result = operators.bmm(source[:, :, None]).squeeze(-1)
241         sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
242         ar_mask = (sequences == 11).long()
243         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
244         return sequences, ar_mask
245
246     def seq2str(self, seq):
247         return "".join("0123456789|>"[x.item()] for x in seq)
248
249
250 ####################
251
252
253 class ProblemGuessOperator(Problem):
254     def __init__(self, len_source=5, len_result=8):
255         self.len_source = len_source
256         self.len_result = len_result
257
258     def generate_sequences(self, nb):
259         operators = F.one_hot(
260             torch.rand(nb, self.len_result, self.len_source).argmax(-1),
261             num_classes=self.len_source,
262         )
263         source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
264         marker1 = torch.full((nb, 1), 10)
265         result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
266         marker2 = torch.full((nb, 1), 11)
267         source2 = torch.randint(10, (nb, self.len_source))
268         marker3 = torch.full((nb, 1), 12)
269         result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
270
271         sequences = torch.cat(
272             (source1, marker1, result1, marker2, source2, marker3, result2), 1
273         )
274         ar_mask = (sequences == 12).long()
275         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
276         return sequences, ar_mask
277
278     def seq2str(self, seq):
279         return "".join("0123456789>|~"[x.item()] for x in seq)
280
281
282 ####################
283
284
285 class ProblemAddition(Problem):
286     def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
287         self.nb_digits = nb_digits
288         self.zero_padded = zero_padded
289         self.inverted_result = inverted_result
290         self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
291         self.id2char = dict([(n, c) for c, n in self.char2id.items()])
292
293     def tensorize(self, strings):
294         len_max = max([len(x) for x in strings])
295         return torch.cat(
296             [
297                 torch.tensor(
298                     [
299                         [self.char2id[c] for c in s + "$" * (len_max - len(s))]
300                         for s in strings
301                     ]
302                 )
303             ],
304             0,
305         )
306
307     def generate_sequences(self, nb):
308         sequences = []
309         for k in range(nb):
310             a, b = torch.randint(10**self.nb_digits, (2,))
311             c = a + b
312             a, b, c = str(a.item()), str(b.item()), str(c.item())
313             if self.zero_padded:
314                 a = "0" * (self.nb_digits - len(a)) + a
315                 b = "0" * (self.nb_digits - len(b)) + b
316                 c = "0" * (self.nb_digits + 1 - len(c)) + c
317             if self.inverted_result:
318                 c = c[::-1]
319             sequences.append(f"{a}+{b}={c}$")
320
321         sequences = self.tensorize(sequences)
322         ar_mask = (sequences == self.char2id["="]).long()
323         ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
324         return sequences, ar_mask
325
326     def seq2str(self, seq):
327         return "".join(self.id2char[x.item()] for x in seq)
328
329
330 ####################
331
332
333 class ProblemMixing(Problem):
334     def __init__(
335         self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
336     ):
337         self.height = height
338         self.width = width
339         self.nb_time_steps = nb_time_steps
340         self.hard = hard
341         self.random_start = random_start
342
343     def start_random(self, nb):
344         y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
345
346         if self.random_start:
347             i = (
348                 torch.arange(self.height)
349                 .reshape(1, -1, 1)
350                 .expand(nb, self.height, self.width)
351             )
352             j = (
353                 torch.arange(self.width)
354                 .reshape(1, 1, -1)
355                 .expand(nb, self.height, self.width)
356             )
357
358             ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
359             rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
360
361             m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
362
363             y = y * m + self.height * self.width * (1 - m)
364
365         y = y.reshape(nb, self.height, self.width)
366
367         return y
368
369     def start_error(self, x):
370         if self.random_start:
371             i = (
372                 torch.arange(self.height, device=x.device)
373                 .reshape(1, -1, 1)
374                 .expand_as(x)
375             )
376             j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
377
378             ri = (
379                 (x == self.height * self.width)
380                 .long()
381                 .sum(dim=-1)
382                 .argmax(-1)
383                 .view(-1, 1, 1)
384             )
385             rj = (
386                 (x == self.height * self.width)
387                 .long()
388                 .sum(dim=-2)
389                 .argmax(-1)
390                 .view(-1, 1, 1)
391             )
392
393             m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
394         else:
395             m = 1
396
397         x = x.flatten(1)
398         u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
399
400         d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
401
402         return d
403
404     def moves(self, x):
405         y = (
406             x[:, None, :, :]
407             .expand(-1, self.height * 2 + self.width * 2, -1, -1)
408             .clone()
409         )
410         k = 0
411
412         for i in range(self.height):
413             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
414             k += 1
415             y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
416             k += 1
417
418         for j in range(self.width):
419             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
420             k += 1
421             y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
422             k += 1
423
424         return y
425
426     def generate_sequences(self, nb):
427         x = self.start_random(nb)
428
429         seq = [x.flatten(1)]
430
431         for t in range(self.nb_time_steps - 1):
432             y = self.moves(x)
433             x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
434             seq.append(x.flatten(1))
435
436         if self.hard:
437             seq.reverse()
438
439         seq = torch.cat(seq, dim=1)
440         return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
441
442     def compute_nb_correct(self, input, ar_mask, result):
443         a = [
444             x.reshape(result.size(0), self.height, self.width)
445             for x in result.split(self.height * self.width, dim=1)
446         ]
447         if self.hard:
448             a.reverse()
449
450         x = a[0]
451
452         d = self.start_error(x)
453
454         for t in range(self.nb_time_steps - 1):
455             x0, x = a[t], a[t + 1]
456             y = self.moves(x0)
457             d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
458
459         nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
460
461         return nb_total, nb_correct
462
463     def seq2str(self, seq):
464         return " | ".join(
465             [
466                 " ".join(
467                     [
468                         "-".join(
469                             [
470                                 f"{x:02d}" if x < self.height * self.width else "**"
471                                 for x in s
472                             ]
473                         )
474                         for s in r.split(self.width)
475                     ]
476                 )
477                 for r in seq.split(self.height * self.width)
478             ]
479         )
480
481
482 ####################
483
484 if __name__ == "__main__":
485     p = ProblemMixing(height=3, width=3, random_start=False)
486
487     s, m = p.generate_sequences(10000)
488     for x in s[:5]:
489         print(p.seq2str(x))
490     print(p.compute_nb_correct(None, None, s))