5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
# Hook for producing `nb` samples. Every Problem* subclass in this file
# overrides it and returns a pair (sequences, mask) of tensors.
# NOTE(review): this looks like a method of the base class `Problem` (the
# subclasses derive from it), but the class header and this method's body are
# not visible in this extract — confirm the default behaviour.
14 def generate_sequences(self, nb):
def seq2str(self, seq):
    """Default rendering hook.

    Returns a fixed placeholder string whatever `seq` contains; the concrete
    problems override this with a real token-to-text rendering.
    """
    placeholder = "[NOT IMPLEMENTED]"
    return placeholder
def compute_nb_correct(self, input, ar_mask, result):
    """Score `result` against `input` on the positions selected by `ar_mask`.

    Returns a pair of Python numbers ``(nb_total, nb_correct)``:
    ``nb_total`` is ``ar_mask.sum()`` (the number of selected positions for a
    0/1 mask) and ``nb_correct`` is the sum of ``ar_mask`` over the positions
    where ``result`` equals ``input``.
    """
    matches = (result == input).long()
    masked_matches = matches * ar_mask
    nb_correct = masked_matches.sum().item()
    nb_total = ar_mask.sum().item()
    return nb_total, nb_correct
# Toy dynamics over `nb_state_tokens` integer state values per time step.
# A sample is a sequence of successive states concatenated along dim 1
# (`torch.cat(seq, dim=1)` in generate_sequences).
27 class ProblemDegradation(Problem):
# nb_state_tokens: values per state; nb_time_steps: loop bound of the
# generator; value_max: value of the single non-zero slot of the first state.
# NOTE(review): `hard` is accepted, but its use is not visible in this
# extract — confirm what it toggles.
28 def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
# NOTE(review): this assert does not guarantee value_max // 4 - 2 >= 1, which
# torch.randint in generate_sequences needs (value_max >= 12). For example,
# value_max=10 with the default nb_state_tokens=5 passes here but makes
# generation raise.
29 assert value_max // nb_state_tokens >= 2
30 self.nb_state_tokens = nb_state_tokens
31 self.nb_time_steps = nb_time_steps
32 self.value_max = value_max
# Returns (seq, mask): seq holds the successive states flattened into one row
# per sample; mask is all ones, so every position is supervised.
35 def generate_sequences(self, nb):
# First state: sort indices of a random row form a permutation, so exactly one
# slot per row equals 0. That slot gets value_max; every other slot gets 0.
37 torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
38 ).long() * self.value_max
41 for t in range(self.nb_time_steps - 1):
# Distinct random ranks 1..nb_state_tokens, zeroed on slots holding less than 2.
42 v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
# u flags the slot(s) equal to the row maximum of v. That is one random slot
# with x >= 2 when one exists; if none does, v is all zeros and every slot is
# flagged.
43 u = (v.max(dim=-1, keepdim=True).values == v).long()
# Per-row amount n. It is capped element-wise by a random value in
# [2, value_max // 4 - 1] and then summed over the state. The first operand of
# .minimum(...) is on a line not visible in this extract.
46 .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
47 .sum(dim=-1, keepdim=True)
# m is a random integer in [1, n - 1] when n >= 2.
49 m = 1 + ((n - 1) * torch.rand(n.size())).long()
# m is added to the cyclic left neighbour of the flagged slot (roll by -1) ...
52 + m * u.roll(shifts=-1, dims=-1)
# ... and n - m to its cyclic right neighbour (roll by +1). The term
# subtracting from the flagged slot itself is not visible in this extract.
54 + (n - m) * u.roll(shifts=1, dims=-1)
61 seq = torch.cat(seq, dim=1)
62 return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
# Validates sequences against the transition rule. nb_total is the number of
# rows of `result`. `input` and `ar_mask` are not used in the visible lines.
64 def compute_nb_correct(self, input, ar_mask, result):
65 nb_total = result.size(0)
# e: expected change between two consecutive states, rebuilt per step below.
67 e = result.new_zeros(self.nb_state_tokens)
# Cut one flat sequence back into its successive states.
70 states = list(seq.split(self.nb_state_tokens))
# NOTE(review): the branch around the next two lines (descending sort, exact
# match test) is only partly visible — confirm what it checks.
75 j = d.sort(descending=True).indices[0]
78 if (d - e).abs().sum() == 0:
# Walk consecutive state pairs: d is the change between them, and j is the
# slot with the smallest (most negative) change.
80 for k in range(len(states) - 1):
81 d = states[k + 1] - states[k]
82 j = d.sort(descending=False).indices[0]
# Presumably rejection conditions (the leading clause and the action taken are
# not visible). NOTE(review): d[j] is the smallest change, so the first clause
# below only fires when every change exceeds value_max // 4. If the intent was
# to bound the loss, `-d[j] > self.value_max // 4` may have been meant.
85 or d[j] > self.value_max // 4
86 or d[(j + 1) % e.size(0)] <= 0
87 or d[(j + 1) % e.size(0)] >= -d[j]
# Expected pattern (indices cyclic): the right neighbour keeps its observed
# gain, and the left neighbour gets the remainder -d[j] - d[j + 1].
93 e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
94 e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
# Any deviation from the expected pattern; presumably marks the row as wrong.
95 if (d - e).abs().sum() > 0:
100 return nb_total, nb_correct
# Renders each state as space-separated two-digit values. How the per-state
# strings are joined is not visible in this extract.
102 def seq2str(self, seq):
104 [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
# Rows of random digits containing two marked windows ("targets") of
# len_targets tokens. The emitted sequence has fields separated by token 12;
# only the separators of that concatenation are visible in this extract.
111 class ProblemTwoTargets(Problem):
112 def __init__(self, len_total=10, len_targets=3):
# A window is two end markers around len_targets - 2 >= 1 interior tokens
# (see a1/a2 below).
113 assert len_targets >= 3
114 assert len_total >= 3 * len_targets - 1
115 self.len_total = len_total
116 self.len_targets = len_targets
# Returns (sequences, ar_mask), with ar_mask = 1 after the first 12 token.
118 def generate_sequences(self, nb):
# k: position-index row of shape (1, len_total). s: random digits 0-9.
119 k = torch.arange(self.len_total)[None, :]
120 s = torch.randint(10, (nb, self.len_total))
# Random scores, zeroed past len_total - len_targets, so the argmax is a start
# position whose len_targets-long window fits in the row.
121 l = torch.rand(nb, self.len_total)
122 l = l * (k <= self.len_total - self.len_targets).long()
123 k1 = l.argmax(dim=1, keepdim=True)
# Overwrite both ends of the first window with token 10 ("-" in seq2str).
124 m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
125 s = s * m + 10 * (1 - m)
# The product below is 1 for start positions k whose window
# [k, k + len_targets - 1] overlaps the first window. The statement using it
# starts on a line not in this extract; presumably it removes those starts
# from l before k2 is chosen.
128 - (k + self.len_targets - 1 >= k1).long()
129 * (k < k1 + self.len_targets).long()
# Second window; its ends are overwritten with token 11 ("+").
131 k2 = l.argmax(dim=1, keepdim=True)
132 m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
133 s = s * m + 11 * (1 - m)
# Interior tokens (len_targets - 2 of them) of each window.
134 a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
135 a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
136 sequences = torch.cat(
139 torch.full((nb, 1), 12),
141 torch.full((nb, 1), 12),
143 torch.full((nb, 1), 12),
# ar_mask is 0 up to and including the first 12 token, and 1 on every
# position after it.
147 ar_mask = (sequences == 12).long()
148 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
149 return sequences, ar_mask
# Token alphabet: 0-9 digits, 10 "-", 11 "+", 12 "|". Expects a 1-D tensor.
151 def seq2str(self, seq):
152 return "".join("0123456789-+|"[x.item()] for x in seq)
class ProblemByHeart(Problem):
    """Recall task over a fixed bank of random digit sentences.

    Each sentence is ``len_prompt`` random digits, the separator token 10
    (rendered ``|``), then ``len_result`` random digits. The bank is drawn once
    at construction, and ``generate_sequences`` only resamples its rows. The
    digits after the separator are therefore fixed per sentence and can only
    be predicted by recalling the bank.
    """

    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
        width = len_prompt + 1 + len_result
        bank = torch.randint(10, (nb_sentences, width))
        bank[:, len_prompt] = 10  # separator between prompt and result
        self.seq = bank

    def generate_sequences(self, nb):
        """Draw ``nb`` rows (with replacement) from the sentence bank.

        Returns ``(sequences, ar_mask)``. ``ar_mask`` is 1 on every position
        strictly after the separator and 0 elsewhere.
        """
        picks = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[picks]
        is_sep = (sequences == 10).long()
        # The cumsum counts separators seen so far, inclusive. Subtracting the
        # position's own flag makes the separator itself 0.
        ar_mask = (is_sep.cumsum(1) - is_sep).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a 1-D token tensor: digits 0-9 as themselves, 10 as ``|``."""
        alphabet = "0123456789|"
        chars = [alphabet[token.item()] for token in seq]
        return "".join(chars)
# Learn a fixed bank of random "copy" operators. Each sample shows an operator
# id, a source of digits, and the result of applying that operator.
176 class ProblemLearnOperator(Problem):
177 def __init__(self, nb_operators=100, len_source=6, len_result=9):
178 self.len_source = len_source
179 self.len_result = len_result
# Number of digits of nb_operators, computed via log10. Presumably this is the
# number of decimal digits used to spell an operator id.
# NOTE(review): `math` is used here, but no `import math` is visible in this
# extract — confirm it is imported at file level.
180 self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
# Operator bank of shape (nb_operators, len_result, len_source), one-hot on the
# last axis: each output position copies one source position.
181 self.operators = F.one_hot(
182 torch.rand(nb_operators, len_result, len_source).argmax(-1),
183 num_classes=len_source,
# Returns (sequences, ar_mask), with ar_mask = 1 after the first 11 token (the
# result part).
186 def generate_sequences(self, nb):
# One random operator id per sample, and the matching operator matrices.
187 nb_operators = torch.randint(self.operators.size(0), (nb,))
188 operators = self.operators[nb_operators]
# Partial expression (its first and last lines are not in this extract). It
# divides each id by 10**(len_nb_operator - 1), ..., 10**0, presumably to get
# its decimal digits, most significant first.
190 nb_operators[:, None]
191 // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
193 marker1 = torch.full((nb, 1), 10)
# len_source distinct digits: a prefix of a random permutation of 0-9.
# NOTE(review): only 10 columns exist, so len_source > 10 would make the bmm
# below fail on a shape mismatch.
194 source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
195 marker2 = torch.full((nb, 1), 11)
# bmm with one-hot rows copies the selected source digit into each output.
196 result = operators.bmm(source[:, :, None]).squeeze(-1)
197 sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
198 ar_mask = (sequences == 11).long()
199 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
200 return sequences, ar_mask
# Token alphabet: 0-9 digits, 10 "|", 11 ">". Expects a 1-D tensor.
202 def seq2str(self, seq):
203 return "".join("0123456789|>"[x.item()] for x in seq)
# Infer a per-sample random "copy" operator from one example
# (source1 -> result1), then apply it to source2; result2 is the target.
209 class ProblemGuessOperator(Problem):
210 def __init__(self, len_source=5, len_result=8):
211 self.len_source = len_source
212 self.len_result = len_result
# Returns (sequences, ar_mask), with ar_mask = 1 after the 12 marker, i.e. on
# result2.
214 def generate_sequences(self, nb):
# One random operator per sample, of shape (nb, len_result, len_source) and
# one-hot on the last axis: each output copies one source position.
215 operators = F.one_hot(
216 torch.rand(nb, self.len_result, self.len_source).argmax(-1),
217 num_classes=self.len_source,
# source1 holds distinct digits (prefix of a random permutation of 0-9), so
# result1 determines which source position each output copies.
# NOTE(review): len_source > 10 would make the bmm below fail on a shape
# mismatch, since only 10 columns exist.
219 source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
220 marker1 = torch.full((nb, 1), 10)
# bmm with one-hot rows copies the selected source digit into each output.
221 result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
222 marker2 = torch.full((nb, 1), 11)
# source2: independent random digits; repeats are allowed.
223 source2 = torch.randint(10, (nb, self.len_source))
224 marker3 = torch.full((nb, 1), 12)
225 result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
227 sequences = torch.cat(
228 (source1, marker1, result1, marker2, source2, marker3, result2), 1
# Only marker3 is 12 (all other tokens are digits or 10/11), so the mask
# covers exactly result2.
230 ar_mask = (sequences == 12).long()
231 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
232 return sequences, ar_mask
# Token alphabet: 0-9 digits, 10 ">", 11 "|", 12 "~". Expects a 1-D tensor.
234 def seq2str(self, seq):
235 return "".join("0123456789>|~"[x.item()] for x in seq)
# Addition task over "a+b=c$" strings. The vocabulary is 0-9, "+", "=", "$",
# and "$" also pads shorter strings.
241 class ProblemAddition(Problem):
# nb_digits bounds the operands (< 10**nb_digits). zero_padded and
# inverted_result are stored for generate_sequences; only the inverted_result
# test is visible in this extract.
242 def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
243 self.nb_digits = nb_digits
244 self.zero_padded = zero_padded
245 self.inverted_result = inverted_result
# Symbol <-> id maps: "0".."9" -> 0..9, "+" -> 10, "=" -> 11, "$" -> 12.
246 self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
247 self.id2char = dict([(n, c) for c, n in self.char2id.items()])
# Encodes each string as a row of ids, right-padded with "$" to the longest
# string. `strings` must be non-empty, since max() of an empty list raises
# ValueError.
249 def tensorize(self, strings):
250 len_max = max([len(x) for x in strings])
255 [self.char2id[c] for c in s + "$" * (len_max - len(s))]
# Returns (sequences, ar_mask), with ar_mask = 1 on the tokens after "="
# (including any "$" padding).
263 def generate_sequences(self, nb):
# Two operands drawn uniformly from [0, 10**nb_digits). The enclosing
# per-sample loop is not visible in this extract.
266 a, b = torch.randint(10**self.nb_digits, (2,))
# c is defined on a line not in this extract (presumably c = a + b).
268 a, b, c = str(a.item()), str(b.item()), str(c.item())
# Left zero-padding: nb_digits for the operands, nb_digits + 1 for c (the sum
# of two numbers below 10**nb_digits has at most nb_digits + 1 digits). The
# guard around these lines is not visible (presumably `if self.zero_padded:`).
270 a = "0" * (self.nb_digits - len(a)) + a
271 b = "0" * (self.nb_digits - len(b)) + b
272 c = "0" * (self.nb_digits + 1 - len(c)) + c
# The body of this branch is not visible (presumably it reverses c).
273 if self.inverted_result:
275 sequences.append(f"{a}+{b}={c}$")
277 sequences = self.tensorize(sequences)
# ar_mask is 1 strictly after the first "=" token.
278 ar_mask = (sequences == self.char2id["="]).long()
279 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
280 return sequences, ar_mask
# Decodes a 1-D id tensor back into its string.
282 def seq2str(self, seq):
283 return "".join(self.id2char[x.item()] for x in seq)
# Board puzzle on a height x width grid. Cells hold 0..height*width-1, and the
# value height*width marks a blanked cell (rendered "**" by seq2str). Each
# generation step cyclically shifts one row or one column by one cell.
289 class ProblemMixing(Problem):
# NOTE(review): the assignments of self.height / self.width (and any use of
# `hard`) are not visible in this extract.
290 def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
293 self.nb_time_steps = nb_time_steps
# Start boards of shape (nb, height, width): cells numbered row-major, with one
# random row and one random column per sample blanked to height*width.
296 def start_random(self, nb):
297 y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
299 # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
# Row index i and column index j of every cell, each of shape (nb, height, width).
301 i = torch.arange(self.height).reshape(1,-1,1).expand(nb,self.height,self.width)
302 j = torch.arange(self.width).reshape(1,1,-1).expand(nb,self.height,self.width)
# Random blanked row ri and column rj per sample.
304 ri = torch.randint(self.height, (nb,)).reshape(nb,1,1)
305 rj = torch.randint(self.width, (nb,)).reshape(nb,1,1)
# m is 0 on the blanked cross (row ri or column rj) and 1 elsewhere, flattened
# row-major to match y.
307 m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
309 y = (y * m + self.height * self.width * (1 - m)).reshape(
310 nb, self.height, self.width
# Distance of boards x from a valid start board. The blanked row (column) is
# recovered as the row (column) containing the most height*width values. d is
# then the sum of |x - expected| over cells, where expected is the row-major
# numbering with that cross blanked.
# NOTE(review): expand_as(x) requires x to be 3-D (nb, height, width), yet L327
# combines x with flattened (nb, height*width) terms. The reshaping in between
# and the return statement are not visible here — confirm.
315 def start_error(self, x):
316 i = torch.arange(self.height, device=x.device).reshape(1,-1,1).expand_as(x)
317 j = torch.arange(self.width, device=x.device).reshape(1,1,-1).expand_as(x)
319 ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
320 rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)
322 m = 1 - torch.logical_or(i==ri,j==rj).long().flatten(1)
325 u = torch.arange(self.height * self.width, device = x.device).reshape(1, -1)
327 d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
# Candidate moves (the enclosing method's header is not visible here). Boards
# are replicated into height*2 + width*2 variants. In each variant k, one row is
# shifted cyclically left (-1) or right (+1), or one column up (-1) or down
# (+1). How k is assigned per case is not visible in this extract.
333 .expand(-1, self.height * 2 + self.width * 2, -1, -1)
338 for i in range(self.height):
339 y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
341 y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
344 for j in range(self.width):
# y[:, k, :, j] is the (nb, height) column, so rolling its last dim is a
# vertical shift.
345 y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
347 y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
# Random walk starting from a random start board. At each of the
# nb_time_steps - 1 steps, every sample takes one uniformly chosen candidate
# from y (y is computed on a line not in this extract, presumably the moves
# of x). Returns the flattened boards and an all-ones mask.
352 def generate_sequences(self, nb):
353 x = self.start_random(nb)
357 for t in range(self.nb_time_steps - 1):
359 x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
360 seq.append(x.flatten(1))
365 seq = torch.cat(seq, dim=1)
366 return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
# A row is counted correct when d == 0. d is start_error of the first board
# plus, per step, the smallest L1 distance between the next board and any
# candidate y (y is computed on a line not in this extract, presumably from
# x0). `input` and `ar_mask` are not used in the visible lines.
368 def compute_nb_correct(self, input, ar_mask, result):
370 x.reshape(result.size(0), self.height, self.width)
371 for x in result.split(self.height * self.width, dim=1)
378 d = self.start_error(x)
380 for t in range(self.nb_time_steps - 1):
381 x0, x = a[t], a[t + 1]
383 d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
385 nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
387 return nb_total, nb_correct
# Renders boards row by row. Values below height*width are printed as two
# digits and blanked cells as "**", with cells joined by "-". The separators
# between rows and boards are not visible in this extract.
389 def seq2str(self, seq):
393 ["-".join([f"{x:02d}" if x < self.height * self.width else "**" for x in s]) for s in r.split(self.width)]
395 for r in seq.split(self.height * self.width)
# Smoke test: generate 10000 sequences with some problem `p` (its construction
# is not visible in this extract) and score them with that problem's own
# checker. input and ar_mask are passed as None, so this only works with a
# compute_nb_correct that ignores them. The generic version using ar_mask.sum()
# and `result == input` would fail.
402 if __name__ == "__main__":
404 s, m = p.generate_sequences(10000)
407 print(p.compute_nb_correct(None, None, s))