5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
# NOTE(review): the enclosing class header and this method's body (original
# lines 15-16) are not visible in this excerpt; presumably a base hook that the
# Problem* subclasses override with their own generators — confirm in the
# full file.
14 def generate_sequences(self, nb):
def seq2str(self, seq):
    """Fallback text rendering: ``seq`` is ignored and a fixed placeholder is returned."""
    placeholder = "[NOT IMPLEMENTED]"
    return placeholder
def compute_nb_correct(self, input, ar_mask, result):
    """Score ``result`` against ``input`` on the positions selected by ``ar_mask``.

    Returns ``(nb_total, nb_correct)``:
      * ``nb_total`` is the sum of ``ar_mask``;
      * ``nb_correct`` is the sum of ``ar_mask`` over the positions where
        ``result`` equals ``input`` elementwise.
    Both values are Python scalars obtained with ``.item()``.
    """
    nb_total = ar_mask.sum().item()
    matches = (result == input).long()
    masked_matches = matches * ar_mask
    nb_correct = masked_matches.sum().item()
    return nb_total, nb_correct
# Three-run toy task: every generated row of length len_total holds value a on
# positions [0, u), b on [u, v) and c on [v, len_total), with a, b, c pairwise
# distinct (see the asserts and the `seq` expression in generate_sequences).
# NOTE(review): this excerpt drops several original lines of this class
# (e.g. 35-36, 51-53, 56-58, 66-68 and most of compute_nb_correct's body);
# the comments below describe only what the visible lines do.
28 class ProblemTwoCuts(Problem):
# len_total: row length; nb_values: a, b, c are drawn from [0, nb_values);
# global_constraint: additionally require floor(mean of the row) to equal
# nb_values // 2 (checked in generate_sequences and compute_nb_correct).
29 def __init__(self, len_total=50, nb_values=100, global_constraint=True):
30 self.len_total = len_total
31 self.nb_values = nb_values
32 self.global_constraint = global_constraint
# NOTE(review): body not visible in this excerpt (original lines 35-36).
34 def generate_sequences_internal(self, nb):
# Returns (seq, ar_mask): seq is an (nb, len_total) int64 tensor of three
# constant runs and ar_mask is an all-ones int64 tensor of the same shape.
37 def generate_sequences(self,nb):
39 u = torch.randint(self.len_total, (nb,))
40 v = torch.randint(self.len_total, (nb,))
42 a = torch.randint(self.nb_values, (nb,))
43 b = torch.randint(self.nb_values, (nb,))
44 c = torch.randint(self.nb_values, (nb,))
# Flag rows whose cut points must be re-drawn. A row survives only when
# v - len_total//5 <= u < v - len_total//10 (i.e. the middle run length v - u
# lies in (len_total//10, len_total//5]) and u != 0.
# NOTE(review): v is drawn with torch.randint(self.len_total, ...), which is
# always < len_total, so the `v == self.len_total` test below never fires.
47 to_compute = torch.logical_or(u>=v-self.len_total//10,u<v-self.len_total//5)
48 to_compute =torch.logical_or(to_compute, u == 0)
49 to_compute =torch.logical_or(to_compute, v == self.len_total)
50 n = to_compute.long().sum()
# Re-draw only the flagged rows; presumably this sits inside a retry loop on
# the missing lines 51-53 — confirm against the full file.
54 u[to_compute] = torch.randint(self.len_total, (n,))
55 v[to_compute] = torch.randint(self.len_total, (n,))
# Flag rows whose values must be re-drawn because two of a, b, c coincide
# (the a == b part presumably starts on missing line 58; the asserts below
# check all three pairs).
59 to_compute = torch.logical_or(to_compute,b==c)
60 to_compute = torch.logical_or(to_compute,a==c)
# a*u + b*(v-u) + c*(len_total-v) is the sum of the row, so this
# floor-divides the row mean and rejects rows where it is not nb_values // 2.
62 if self.global_constraint:
63 to_compute = torch.logical_or(to_compute,(a*u+b*(v-u)+c*(self.len_total-v)) // self.len_total != self.nb_values//2)
65 n = to_compute.long().sum()
69 a[to_compute] = torch.randint(self.nb_values, (n,))
70 b[to_compute] = torch.randint(self.nb_values, (n,))
71 c[to_compute] = torch.randint(self.nb_values, (n,))
# NOTE(review): these are `assert` statements, so the checks disappear
# under `python -O`.
73 assert (u>=v).long().sum() == 0
74 assert (a==b).long().sum() == 0
75 assert (a==c).long().sum() == 0
76 assert (c==b).long().sum() == 0
# Broadcast the position index t against the per-row cut points u, v to
# paint the three runs a | b | c.
78 t = torch.arange(self.len_total)
79 seq = (t[None,:] < u[:,None]).long() * a[:,None] + \
80 (t[None,:] >= u[:,None]).long() * (t[None,:] < v[:,None]).long() * b[:,None] + \
81 (t[None,:] >= v[:,None]).long() * c[:,None]
83 return seq,seq.new_full(seq.size(), 1, dtype=torch.int64)
# Checks which rows of `result` split into three constant runs as produced by
# generate_sequences. Returns (nb_total, nb_correct) with nb_total = number
# of rows. The visible lines never read `input` or `ar_mask` (the __main__
# block passes None for both).
# NOTE(review): nb_correct, s, a, u, b, v and c are defined on lines missing
# from this excerpt (e.g. 87, 91-92, 94-96, 98-100); the per-line notes below
# hedge their meaning accordingly.
85 def compute_nb_correct(self, input, ar_mask, result):
86 nb_total = result.size(0)
# Position index used to restrict the searches below to "at or after" a cut.
88 i = torch.arange(result.size(1), device=result.device)
90 for k in range(nb_total):
# Positions where s differs from a (presumably the row's first value).
93 uu = (s != a).nonzero()
# Positions at or after u where s differs from b.
97 vv = torch.logical_and(s != b, i >= u).nonzero()
# Positions at or after v where s differs from c; presumably empty for a
# valid three-run row.
101 ww = torch.logical_and(s != c, i >= v).nonzero()
# Same floor-of-mean constraint as in generate_sequences.
103 if not self.global_constraint or (a*u+b*(v-u)+c*(self.len_total-v)) // self.len_total == self.nb_values//2:
106 return nb_total, nb_correct
# Render each value as a zero-padded two-digit number, space-separated.
108 def seq2str(self, seq):
109 return " ".join( [ f"{x:02d}" for x in seq ] )
# Random digit rows (0-9) containing two windows of length len_targets whose
# end positions are marked with token 10 (first window) and 11 (second
# window); the output is assembled with 12-separated fields.
# NOTE(review): original lines 129-130, 133 and 140-149 are missing from this
# excerpt, so the full torch.cat argument list is not visible.
114 class ProblemTwoTargets(Problem):
# len_total >= 3 * len_targets - 1 guarantees that, wherever a first window of
# length len_targets is placed, a second non-overlapping window fits on one
# side of it.
115 def __init__(self, len_total=10, len_targets=3):
116 assert len_targets >= 3
117 assert len_total >= 3 * len_targets - 1
118 self.len_total = len_total
119 self.len_targets = len_targets
121 def generate_sequences(self, nb):
122 k = torch.arange(self.len_total)[None, :]
123 s = torch.randint(10, (nb, self.len_total))
# Random scores per start position; zero the scores of starts whose window
# would run past the end, so argmax picks k1 <= len_total - len_targets.
124 l = torch.rand(nb, self.len_total)
125 l = l * (k <= self.len_total - self.len_targets).long()
126 k1 = l.argmax(dim=1, keepdim=True)
# Put token 10 at both ends of the first window: k1 and k1 + len_targets - 1.
127 m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
128 s = s * m + 10 * (1 - m)
# Continuation of an expression started on missing lines 129-130. The product
# below is 1 exactly for start positions k whose window [k, k+len_targets-1]
# intersects the first window; presumably those starts are zeroed in l so the
# second window cannot overlap the first — confirm.
131 - (k + self.len_targets - 1 >= k1).long()
132 * (k < k1 + self.len_targets).long()
134 k2 = l.argmax(dim=1, keepdim=True)
# Put token 11 at both ends of the second window.
135 m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
136 s = s * m + 11 * (1 - m)
# The len_targets - 2 interior tokens of each window.
137 a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
138 a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
139 sequences = torch.cat(
142 torch.full((nb, 1), 12),
144 torch.full((nb, 1), 12),
146 torch.full((nb, 1), 12),
# ar_mask is 1 on every position strictly after the first 12 token, else 0.
150 ar_mask = (sequences == 12).long()
151 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
152 return sequences, ar_mask
# Digits map to themselves; 10 -> "-", 11 -> "+", 12 -> "|".
154 def seq2str(self, seq):
155 return "".join("0123456789-+|"[x.item()] for x in seq)
class ProblemByHeart(Problem):
    """Memorisation task over a fixed bank of random digit sentences.

    Every sentence is ``len_prompt`` random digits (0-9), the separator token
    10, then ``len_result`` random digits. The bank is drawn once, in
    ``__init__``, and stored as ``self.seq``.
    """

    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
        sentence_len = len_prompt + 1 + len_result
        bank = torch.randint(10, (nb_sentences, sentence_len))
        bank[:, len_prompt] = 10  # separator between prompt and result
        self.seq = bank

    def generate_sequences(self, nb):
        """Draw ``nb`` sentences from the bank, with replacement.

        Returns ``(sequences, ar_mask)``. ``ar_mask`` is 1 on every position
        that comes strictly after the first token equal to 10, and 0 elsewhere.
        """
        rows = torch.randint(self.seq.size(0), (nb,))
        sequences = self.seq[rows]
        is_separator = (sequences == 10).long()
        # Count the separators strictly before each position, then saturate at 1.
        seen_before = is_separator.cumsum(1) - is_separator
        ar_mask = torch.clamp(seen_before, max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render token ids as text: digits as themselves, 10 as ``|``."""
        alphabet = "0123456789|"
        return "".join([alphabet[token.item()] for token in seq])
# Each row: the decimal digits of an operator index, marker 10, a source of
# len_source distinct digits, marker 11, then that operator applied to the
# source (ar_mask covers everything after marker 11).
# NOTE(review): original lines 187, 192 and 195 are missing from this excerpt.
179 class ProblemLearnOperator(Problem):
180 def __init__(self, nb_operators=100, len_source=6, len_result=9):
181 self.len_source = len_source
182 self.len_result = len_result
# Digit count of nb_operators, used as the width of the operator index.
# NOTE(review): `import math` is not visible in this excerpt — confirm it is
# imported. The float log ratio is inexact at powers of ten
# (math.log(1000) / math.log(10) == 2.9999999999999996), so 100 gives 3 digits
# while 1000 also gives 3. That is never fewer than the digits of
# nb_operators - 1, but `len(str(nb_operators - 1))` would be exact.
183 self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
# One random one-hot selector per (operator, result position): each result
# position copies one source position. The closing of this call (original
# line 187) is not visible here.
184 self.operators = F.one_hot(
185 torch.rand(nb_operators, len_result, len_source).argmax(-1),
186 num_classes=len_source,
189 def generate_sequences(self, nb):
190 nb_operators = torch.randint(self.operators.size(0), (nb,))
191 operators = self.operators[nb_operators]
# Fragment of an expression spanning missing lines 192 and 195: it divides
# each index by descending powers of ten, presumably to extract its decimal
# digits and rebind nb_operators to them — confirm.
193 nb_operators[:, None]
194 // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
196 marker1 = torch.full((nb, 1), 10)
# First len_source entries of a random permutation of 0..9, i.e. distinct
# digits. indices has only 10 columns, so len_source must not exceed 10 or the
# bmm below cannot match the operator width.
197 source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
198 marker2 = torch.full((nb, 1), 11)
# One-hot operator times source column: each result entry is one source digit.
199 result = operators.bmm(source[:, :, None]).squeeze(-1)
200 sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
# ar_mask is 1 on every position strictly after the first 11 token, else 0.
201 ar_mask = (sequences == 11).long()
202 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
203 return sequences, ar_mask
# Digits map to themselves; 10 -> "|", 11 -> ">".
205 def seq2str(self, seq):
206 return "".join("0123456789|>"[x.item()] for x in seq)
# Each row shows a per-row random operator once (source1 -> result1) and then
# asks for its output on a second source: source1, 10, result1, 11, source2,
# 12, result2. ar_mask covers everything after the 12 token.
# NOTE(review): original lines 221 and 232 (presumably the closing parentheses
# of the two multi-line calls) are missing from this excerpt.
212 class ProblemGuessOperator(Problem):
213 def __init__(self, len_source=5, len_result=8):
214 self.len_source = len_source
215 self.len_result = len_result
217 def generate_sequences(self, nb):
# One random one-hot selector per (row, result position).
218 operators = F.one_hot(
219 torch.rand(nb, self.len_result, self.len_source).argmax(-1),
220 num_classes=self.len_source,
# source1 uses distinct digits (first len_source entries of a permutation of
# 0..9), presumably so the example pins down which source position each result
# entry copies; source2 may repeat digits.
222 source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
223 marker1 = torch.full((nb, 1), 10)
224 result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
225 marker2 = torch.full((nb, 1), 11)
226 source2 = torch.randint(10, (nb, self.len_source))
227 marker3 = torch.full((nb, 1), 12)
228 result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
230 sequences = torch.cat(
231 (source1, marker1, result1, marker2, source2, marker3, result2), 1
# ar_mask is 1 on every position strictly after the first 12 token, else 0.
233 ar_mask = (sequences == 12).long()
234 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
235 return sequences, ar_mask
# Digits map to themselves; 10 -> ">", 11 -> "|", 12 -> "~".
237 def seq2str(self, seq):
238 return "".join("0123456789>|~"[x.item()] for x in seq)
# Addition task rendered as strings "a+b=c$" and tensorized through a
# 13-symbol vocabulary; ar_mask covers everything after "=".
# NOTE(review): most of tensorize (lines 254-257, 259-264) and of
# generate_sequences (lines 267-268, 270, 272, 277, 279) are missing from this
# excerpt.
244 class ProblemAddition(Problem):
245 def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
246 self.nb_digits = nb_digits
247 self.zero_padded = zero_padded
248 self.inverted_result = inverted_result
# Vocabulary: "0"-"9" -> 0-9, "+" -> 10, "=" -> 11, "$" -> 12.
249 self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
250 self.id2char = dict([(n, c) for c, n in self.char2id.items()])
# Right-pads each string with "$" up to the longest length before mapping
# characters to ids; the remaining lines are not visible here.
# NOTE(review): max() raises ValueError when `strings` is empty.
252 def tensorize(self, strings):
253 len_max = max([len(x) for x in strings])
258 [self.char2id[c] for c in s + "$" * (len_max - len(s))]
266 def generate_sequences(self, nb):
# a, b uniform in [0, 10**nb_digits). NOTE(review): presumably fails for
# nb_digits >= 19, since 10**19 exceeds the int64 range torch.randint needs —
# confirm. c is defined on missing line 270 (presumably a + b).
269 a, b = torch.randint(10**self.nb_digits, (2,))
271 a, b, c = str(a.item()), str(b.item()), str(c.item())
# Left-pad operands to nb_digits and the result to nb_digits + 1 digits; the
# guarding condition is on missing line 272 (presumably self.zero_padded).
273 a = "0" * (self.nb_digits - len(a)) + a
274 b = "0" * (self.nb_digits - len(b)) + b
275 c = "0" * (self.nb_digits + 1 - len(c)) + c
# Body of this branch (line 277) is not visible; presumably reverses c.
276 if self.inverted_result:
278 sequences.append(f"{a}+{b}={c}$")
280 sequences = self.tensorize(sequences)
# ar_mask is 1 on every position strictly after the first "=", else 0.
281 ar_mask = (sequences == self.char2id["="]).long()
282 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
283 return sequences, ar_mask
285 def seq2str(self, seq):
286 return "".join(self.id2char[x.item()] for x in seq)
# Smoke test: generate 10000 ProblemTwoCuts rows with len_total=12 and print
# the (nb_total, nb_correct) pair that compute_nb_correct reports for them.
# The returned ar_mask `m` is unused.
# NOTE(review): this excerpt may end before the block does.
289 if __name__ == "__main__":
290 p = ProblemTwoCuts(12)
291 s, m = p.generate_sequences(10000)
292 print(p.compute_nb_correct(None, None, s))