5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
# Base-class hook that builds a batch of nb sequences. Its body is not visible
# in this view; every subclass below returns a (sequences, ar_mask) pair.
14 def generate_sequences(self, nb):
def seq2str(self, seq):
    """Return a printable form of ``seq``; the base class only has a placeholder."""
    placeholder = "[NOT IMPLEMENTED]"
    return placeholder
def compute_nb_correct(self, input, ar_mask, result):
    """Count the scored positions and how many of them ``result`` reproduces.

    ``ar_mask`` selects which positions are scored. Returns the pair
    ``(nb_total, nb_correct)`` as Python scalars.
    """
    hits = (result == input).long()
    nb_correct = (hits * ar_mask).sum().item()
    nb_total = ar_mask.sum().item()
    return nb_total, nb_correct
# Task whose sequences are nb_time_steps consecutive states of nb_state_tokens
# integers, concatenated along dim 1. Each step only moves value between
# neighbouring slots, so the total of a state is conserved (see update below).
28 class ProblemDegradation(Problem):
# NOTE(review): `hard` is accepted here and `self.hard` is read below, but no
# `self.hard = hard` assignment is visible in this view -- confirm it exists.
29 def __init__(self, nb_state_tokens=5, nb_time_steps=5, value_max=25, hard=False):
30 self.nb_state_tokens = nb_state_tokens
31 self.nb_time_steps = nb_time_steps
32 self.value_max = value_max
# Returns (sequences, ar_mask); ar_mask is all ones, so every token is a
# prediction target.
35 def generate_sequences(self,nb):
# Initial state: value_max in one uniformly chosen slot (the position that
# receives index 0 after sorting random keys), 0 everywhere else.
37 x = (torch.rand(nb,self.nb_state_tokens).sort(dim=-1).indices == 0).long() * self.value_max
# NOTE(review): the list `seq` used below is built on lines not visible here
# (presumably initialised with x and appended to once per step).
40 for t in range(self.nb_time_steps-1):
# Random weights on slots with a positive value; u marks the slot with the
# largest weight (almost surely a single slot).
41 v = torch.rand(x.size()) * (x > 0).float()
42 u = (v.max(dim=-1,keepdim=True).values == v).long()
# n: a random amount, at most half of the chosen slot's value.
43 n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2
# Remove 2*n from the chosen slot and add n to each cyclic neighbour, which
# keeps the row sum unchanged.
44 x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1))
# In "hard" mode the states are emitted in reverse time order.
47 if self.hard: seq.reverse()
49 seq = torch.cat(seq,dim=1)
50 return seq,seq.new_full(seq.size(), 1, dtype=torch.int64)
# Scores whole sequences: nb_total is the number of rows in `result`. A row
# presumably counts as correct only when its first state and every transition
# match the generation rule (several lines of this method are not visible).
52 def compute_nb_correct(self, input, ar_mask, result):
53 nb_total = result.size(0)
# NOTE(review): the loop header that binds `seq` (one row of result) is not
# visible here.
57 states = list(seq.split(self.nb_state_tokens))
# j: slot holding the largest value of d; e is the reference state it is
# compared against (its filling is not visible here).
62 j=d.sort(descending=True).indices[0]
63 e=d.new_zeros(d.size())
65 if (d-e).abs().sum() == 0:
67 for k in range(len(states)-1):
# d: per-slot decrease from state k to state k+1; j is the slot that lost the
# most, i.e. the slot chosen by the generator at that step.
68 d=states[k]-states[k+1]
69 j=d.sort(descending=True).indices[0]
70 e=d.new_zeros(d.size())
# Expected change of j's two cyclic neighbours: each gained half of what j
# lost. Note `-d[j]//2` parses as `(-d[j]) // 2`.
# NOTE(review): with nb_state_tokens == 2 both neighbours are the same slot,
# which gains 2*n during generation but is assigned only -d[j]//2 here, so such
# transitions would be rejected -- confirm whether that size is supported.
72 e[(j+1)%e.size(0)]=-d[j]//2
73 e[(j-1)%e.size(0)]=-d[j]//2
74 if (d-e).abs().sum() > 0:
79 return nb_total, nb_correct
# Formats each state as two-digit numbers, with states separated by " | ".
81 def seq2str(self, seq):
82 return " | ".join( [ " ".join([f"{x:02d}" for x in s ]) for s in seq.split(self.nb_state_tokens) ] )
# Task: a random digit string contains two non-overlapping windows of length
# len_targets. The first window's two end positions are overwritten with 10,
# the second window's with 11. a1/a2 hold each window's len_targets - 2
# interior digits; they are presumably what follows the 12 separators (the
# torch.cat arguments are only partly visible).
87 class ProblemTwoTargets(Problem):
88 def __init__(self, len_total=10, len_targets=3):
# At least one interior digit per window.
89 assert len_targets >= 3
# Whatever the first window's position, the free room on one side is then at
# least len_targets, so a second non-overlapping window always fits.
90 assert len_total >= 3 * len_targets - 1
91 self.len_total = len_total
92 self.len_targets = len_targets
94 def generate_sequences(self, nb):
95 k = torch.arange(self.len_total)[None, :]
96 s = torch.randint(10, (nb, self.len_total))
# Random scores, zeroed for start positions where a window would not fit;
# argmax then picks a random admissible start k1.
97 l = torch.rand(nb, self.len_total)
98 l = l * (k <= self.len_total - self.len_targets).long()
99 k1 = l.argmax(dim=1, keepdim=True)
# Overwrite both ends of the first window with token 10.
100 m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
101 s = s * m + 10 * (1 - m)
# NOTE(review): continuation of an expression whose first lines are not
# visible. The factor is 1 exactly for starts k in
# [k1 - len_targets + 1, k1 + len_targets - 1], i.e. windows overlapping the
# first one, which it presumably removes from l.
104 - (k + self.len_targets - 1 >= k1).long()
105 * (k < k1 + self.len_targets).long()
# Second window start; its ends are overwritten with token 11.
107 k2 = l.argmax(dim=1, keepdim=True)
108 m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
109 s = s * m + 11 * (1 - m)
# Interior digits of each window (positions start+1 .. start+len_targets-2).
110 a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
111 a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
112 sequences = torch.cat(
115 torch.full((nb, 1), 12),
117 torch.full((nb, 1), 12),
119 torch.full((nb, 1), 12),
# Prediction targets: every position after the first 12 separator.
123 ar_mask = (sequences == 12).long()
124 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
125 return sequences, ar_mask
# Tokens 10, 11, 12 render as '-', '+', '|'.
127 def seq2str(self, seq):
128 return "".join("0123456789-+|"[x.item()] for x in seq)
class ProblemByHeart(Problem):
    """Memorisation task: replay sentences drawn from a fixed random bank.

    Each sentence is ``len_prompt`` random digits, the separator token 10, then
    ``len_result`` random digits. Only the part after the separator is a
    prediction target.
    """

    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
        # Bank of random digit strings; column len_prompt is the separator.
        bank = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
        bank[:, len_prompt] = 10
        self.seq = bank

    def generate_sequences(self, nb):
        """Sample ``nb`` bank sentences (with replacement) and their target mask."""
        picks = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[picks]
        # 1 at every position after the separator; the separator itself is 0.
        is_sep = (sequences == 10).long()
        ar_mask = (is_sep.cumsum(1) - is_sep).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render digits as themselves and the separator as '|'."""
        alphabet = "0123456789|"
        return "".join([alphabet[token.item()] for token in seq])
# Task with a fixed bank of nb_operators random "copy" operators. Each sequence
# is: the operator index written in decimal digits, 10, len_source distinct
# source digits, 11, then the chosen operator applied to the source.
152 class ProblemLearnOperator(Problem):
153 def __init__(self, nb_operators=100, len_source=6, len_result=9):
154 self.len_source = len_source
155 self.len_result = len_result
# Number of decimal digits used to write an operator index.
# NOTE(review): `math` is not imported on the visible import lines; it must be
# imported on a line not shown here.
156 self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
# One-hot rows: each of the len_result output positions selects exactly one
# source position. Shape (nb_operators, len_result, len_source). The call's
# closing parenthesis is not visible in this view.
157 self.operators = F.one_hot(
158 torch.rand(nb_operators, len_result, len_source).argmax(-1),
159 num_classes=len_source,
162 def generate_sequences(self, nb):
# Random operator index per sequence, and that operator's matrix.
163 nb_operators = torch.randint(self.operators.size(0), (nb,))
164 operators = self.operators[nb_operators]
# NOTE(review): fragment of the index-to-digits conversion; the surrounding
# lines (presumably `nb_operators = (` and `) % 10`) are not visible.
166 nb_operators[:, None]
167 // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
169 marker1 = torch.full((nb, 1), 10)
# First len_source entries of a random permutation of 0..9, so the source
# digits are distinct (requires len_source <= 10).
170 source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
171 marker2 = torch.full((nb, 1), 11)
# Each result token copies the source token its one-hot row selects.
172 result = operators.bmm(source[:, :, None]).squeeze(-1)
173 sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
# Prediction targets: everything after the first 11 marker.
174 ar_mask = (sequences == 11).long()
175 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
176 return sequences, ar_mask
# Tokens 10 and 11 render as '|' and '>'.
178 def seq2str(self, seq):
179 return "".join("0123456789|>"[x.item()] for x in seq)
# Task: infer an operator from a single example. Each sequence gets its own
# random "copy" operator. It shows source1 -> result1, then gives a new source2
# whose image result2 is the prediction target.
185 class ProblemGuessOperator(Problem):
186 def __init__(self, len_source=5, len_result=8):
187 self.len_source = len_source
188 self.len_result = len_result
190 def generate_sequences(self, nb):
# One-hot rows of shape (nb, len_result, len_source): each output position
# copies one source position. The call's closing parenthesis is not visible.
191 operators = F.one_hot(
192 torch.rand(nb, self.len_result, self.len_source).argmax(-1),
193 num_classes=self.len_source,
# Example input: distinct digits (prefix of a random permutation of 0..9).
195 source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
196 marker1 = torch.full((nb, 1), 10)
197 result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
198 marker2 = torch.full((nb, 1), 11)
# Query input: independent random digits, repeats allowed.
199 source2 = torch.randint(10, (nb, self.len_source))
200 marker3 = torch.full((nb, 1), 12)
201 result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
# NOTE(review): the closing parenthesis of this torch.cat is not visible.
203 sequences = torch.cat(
204 (source1, marker1, result1, marker2, source2, marker3, result2), 1
# Prediction targets: everything after the 12 marker, i.e. result2.
206 ar_mask = (sequences == 12).long()
207 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
208 return sequences, ar_mask
# Tokens 10, 11, 12 render as '>', '|', '~'.
210 def seq2str(self, seq):
211 return "".join("0123456789>|~"[x.item()] for x in seq)
# Task: strings of the form "a+b=c$" over the alphabet "0123456789+=$".
# Everything after '=' is a prediction target.
217 class ProblemAddition(Problem):
218 def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
219 self.nb_digits = nb_digits
220 self.zero_padded = zero_padded
221 self.inverted_result = inverted_result
# Digits map to ids 0-9, '+' -> 10, '=' -> 11, '$' -> 12; id2char is the inverse.
222 self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
223 self.id2char = dict([(n, c) for c, n in self.char2id.items()])
# Right-pads each string with '$' to the longest length and maps characters to
# ids. The rest of this method is not visible in this view.
225 def tensorize(self, strings):
226 len_max = max([len(x) for x in strings])
231 [self.char2id[c] for c in s + "$" * (len_max - len(s))]
239 def generate_sequences(self, nb):
# Operands drawn uniformly from [0, 10**nb_digits).
# NOTE(review): torch.randint uses int64 by default, so nb_digits >= 19 would
# exceed its range -- confirm the intended limit.
242 a, b = torch.randint(10**self.nb_digits, (2,))
# NOTE(review): c is computed on a line not visible here (presumably a + b).
244 a, b, c = str(a.item()), str(b.item()), str(c.item())
# Left-pad operands to nb_digits and the result to nb_digits + 1. The guarding
# condition (presumably self.zero_padded) is not visible.
246 a = "0" * (self.nb_digits - len(a)) + a
247 b = "0" * (self.nb_digits - len(b)) + b
248 c = "0" * (self.nb_digits + 1 - len(c)) + c
# NOTE(review): the body of this branch is not visible (presumably reverses c).
249 if self.inverted_result:
251 sequences.append(f"{a}+{b}={c}$")
253 sequences = self.tensorize(sequences)
# Prediction targets: every position after the first '='.
254 ar_mask = (sequences == self.char2id["="]).long()
255 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
256 return sequences, ar_mask
258 def seq2str(self, seq):
259 return "".join(self.id2char[x.item()] for x in seq)
if __name__ == "__main__":
    # Smoke test: generate a batch of degradation sequences, print the first
    # one, then score the generated batch with compute_nb_correct.
    problem = ProblemDegradation(hard=False)
    sequences, _ar_mask = problem.generate_sequences(10000)
    print(problem.seq2str(sequences[0]))
    print(problem.compute_nb_correct(None, None, sequences))