#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 import mygpt
16 from mygpt import BracketedSequence
17
######################################################################

# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated. The others are kept
# unchanged.

def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
):
    """Autoregressively re-generate, in place, the positions of input
    where ar_mask is 1, one token position at a time.

    Args:
        model: callable taking a BracketedSequence and returning one; its
            KV-cache is warmed up on the unmasked prefix.
        input: (N, T) long tensor, modified in place.
        ar_mask: (N, T) 0/1 tensor, 1 where tokens must be generated.
        seq_logproba: (N,) tensor; the log-proba of each chosen token is
            accumulated into it with +=.
        temperature: softmax temperature applied to the logits.
        deterministic_synthesis: if True, take the argmax instead of
            sampling from the categorical distribution.
    """
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    if to_generate.min() > 0:
        model(
            BracketedSequence(input, 0, to_generate.min())
        )  # Needed to initialize the model's cache

    for s in range(to_generate.min(), to_generate.max() + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = output[:, s]

        logits = (logits / temperature).log_softmax(dim=-1)

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        # Index tensor on the same device as the logits, so this also
        # works when the model runs on GPU
        all_n = torch.arange(t_next.size(0), device=t_next.device)

        # BUG FIX: logits[all_n, t_next] is already one value per
        # sequence; the previous .sum(dim=-1) collapsed it to a single
        # scalar (the sum over the whole batch) that was then
        # broadcast-added to every entry of seq_logproba.
        # NOTE(review): the log-proba is accumulated even at positions
        # where a given sequence keeps its original token — confirm this
        # is the intended semantics.
        seq_logproba += logits[all_n, t_next]

        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
57
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    """Autoregressively fill the masked positions of input, in place,
    processing it in chunks of batch_size with the model in eval mode
    and gradients disabled.

    Args:
        model: autoregressive model fed BracketedSequence chunks.
        batch_size: number of sequences per chunk.
        input: (N, T) long tensor, modified in place.
        ar_mask: (N, T) 0/1 tensor, same size as input, 1 where tokens
            must be generated.
        seq_logproba: (N,) tensor accumulating per-sequence log-probas.
        temperature: softmax temperature for sampling.
        deterministic_synthesis: if True, argmax instead of sampling.
        forbidden_tokens, logit_biases: currently unused; kept for
            interface compatibility with callers.
        progress_bar_desc: if not None, wrap the chunks in a tqdm bar.
        device: unused here; tensors are processed on whatever device
            they already live on.
    """
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            # ceil(N / batch_size)
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    # torch.no_grad() is the canonical spelling of the legacy
    # torch.autograd.no_grad()
    with torch.no_grad():
        t = model.training
        model.eval()

        try:
            for input, ar_mask, seq_logproba in batches:
                one_batch_masked_inplace_autoregression(
                    model=model,
                    input=input,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba,
                    temperature=temperature,
                    deterministic_synthesis=deterministic_synthesis,
                )
        finally:
            # Restore the original train/eval mode even if generation
            # raises, so a failure does not leave the model in eval mode
            model.train(t)
102
103
104 ######################################################################
105
106
class QuizzMachine:
    def make_ar_mask(self, input):
        """Return a 0/1 long mask shaped like input, selecting the
        positions strictly past the sequence midpoint (the part to be
        re-generated)."""
        positions = torch.arange(input.size(1), device=input.device)
        past_midpoint = (positions > input.size(1) // 2).long()
        return past_midpoint[None, :].expand_as(input)
111
112     def generate_token_sequences(self, nb):
113         prompts, answers = self.problem.generate_prompts_and_answers(nb)
114         result = []
115
116         for prompt, answer in zip(prompts, answers):
117             if torch.rand(1) < 0.5:
118                 a = [torch.tensor([self.token_forward]), prompt, answer]
119             else:
120                 a = [torch.tensor([self.token_backward]), answer, prompt]
121
122             result.append(torch.cat(a, dim=0)[None, :])
123
124         return torch.cat(result, dim=0)
125
126     def __init__(
127         self,
128         problem,
129         nb_train_samples,
130         nb_test_samples,
131         batch_size,
132         result_dir,
133         logger,
134         device=torch.device("cpu"),
135     ):
136         super().__init__()
137
138         v = problem.nb_token_values()
139         self.token_forward = v
140         self.token_backward = v + 1
141         self.nb_token_values = v + 2
142
143         self.problem = problem
144         self.batch_size = batch_size
145         self.device = device
146         self.logger = logger
147
148         self.train_w_quizzes = self.generate_token_sequences(nb_train_samples).to(
149             device
150         )
151
152         self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
153
154         self.train_c_quizzes = []
155         self.test_c_quizzes = []
156
157         if result_dir is not None:
158             self.save_quizzes(
159                 result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
160             )
161
162     def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
163         l = (quizzes.size(1) - 1) // 2
164         forward = (quizzes[:, 0] == self.token_forward).long()
165         backward = (quizzes[:, 0] == self.token_backward).long()
166         assert forward.equal(1 - backward)
167         first = quizzes[:, 1 : 1 + l]
168         second = quizzes[:, 1 + l : 1 + 2 * l]
169         prompts = forward[:, None] * first + backward[:, None] * second
170         answers = forward[:, None] * second + backward[:, None] * first
171
172         if prediction:
173             predicted_prompts = backward
174             predicted_answers = forward
175         else:
176             predicted_prompts = None
177             predicted_answers = None
178
179         self.problem.save_quizzes(
180             result_dir,
181             filename_prefix,
182             prompts,
183             answers,
184             predicted_prompts,
185             predicted_answers,
186         )
187
188     def batches(self, split="train", desc=None):
189         assert split in {"train", "test"}
190         if split == "train":
191             w_quizzes = self.train_w_quizzes
192             c_quizzes = self.train_c_quizzes
193         else:
194             w_quizzes = self.test_w_quizzes
195             c_quizzes = self.test_c_quizzes
196
197         if len(c_quizzes) > 0:
198             c_quizzes = torch.cat(c_quizzes, dim=0)
199             if c_quizzes.size(0) > w_quizzes.size(0) // 2:
200                 i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
201                 c_quizzes = c_quizzes[i]
202
203             i = torch.randperm(w_quizzes.size(0))[
204                 : w_quizzes.size(0) - c_quizzes.size(0)
205             ]
206             w_quizzes = w_quizzes[i]
207
208             self.nb_batch_w_quizzes = w_quizzes.size(0)
209             self.nb_batch_c_quizzes = c_quizzes.size(0)
210
211             input = torch.cat([w_quizzes, c_quizzes], dim=0)
212         else:
213             input = w_quizzes
214             self.nb_batch_w_quizzes = w_quizzes.size(0)
215             self.nb_batch_c_quizzes = 0
216
217         # Shuffle
218         input = input[torch.randperm(input.size(0))]
219
220         if desc is None:
221             desc = f"epoch-{split}"
222         for batch in tqdm.tqdm(
223             input.split(self.batch_size), dynamic_ncols=True, desc=desc
224         ):
225             yield batch
226
227     def vocabulary_size(self):
228         return self.nb_token_values
229
230     def produce_results(
231         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
232     ):
233         def compute_accuracy(input):
234             input = input[:nmax]
235             ar_mask = self.make_ar_mask(input)
236             result = input.clone() * (1 - ar_mask)
237             seq_logproba = torch.empty(input.size(0), device=self.device)
238
239             masked_inplace_autoregression(
240                 model=model,
241                 batch_size=self.batch_size,
242                 input=result,
243                 ar_mask=ar_mask,
244                 seq_logproba=seq_logproba,
245                 temperature=1.0,
246                 deterministic_synthesis=deterministic_synthesis,
247                 progress_bar_desc=None,
248                 device=self.device,
249             )
250
251             nb_total, nb_correct = (
252                 input.size(0),
253                 (input == result).long().min(dim=1).values.sum(),
254             )
255
256             return nb_total, nb_correct
257
258         train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
259
260         self.logger(
261             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
262         )
263
264         test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
265
266         self.logger(
267             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
268         )
269
270         main_test_accuracy = test_nb_correct / test_nb_total
271         self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
272
273         ##############################
274
275         input = self.test_w_quizzes[:96]
276         ar_mask = self.make_ar_mask(input)
277         result = input.clone() * (1 - ar_mask)
278         seq_logproba = torch.empty(input.size(0), device=self.device)
279
280         masked_inplace_autoregression(
281             model=model,
282             batch_size=self.batch_size,
283             input=result,
284             ar_mask=ar_mask,
285             seq_logproba=seq_logproba,
286             temperature=1.0,
287             deterministic_synthesis=deterministic_synthesis,
288             progress_bar_desc=None,
289             device=self.device,
290         )
291
292         self.save_quizzes(
293             result_dir,
294             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
295             quizzes=result[:72],
296             prediction=True,
297         )
298
299         return main_test_accuracy
300
301     def renew_w_quizzes(self, nb, for_train=True):
302         input = self.train_w_quizzes if for_train else self.test_w_quizzes
303         nb = min(nb, input.size(0))
304         input[:-nb] = input[nb:].clone()
305         input[-nb:] = self.generate_token_sequences(nb).to(self.device)
306
307     def store_c_quizzes(self, new_c_quizzes, for_train=True):
308         if for_train:
309             self.train_c_quizzes.append(new_c_quizzes)
310         else:
311             self.test_c_quizzes.append(new_c_quizzes)
312
313     def reverse_time(self, c_quizzes):
314         l = (c_quizzes.size(1) - 1) // 2
315         direction = c_quizzes[:, 0:1]
316         direction = self.token_forward * (
317             direction == self.token_backward
318         ) + self.token_backward * (direction == self.token_forward)
319
320         return torch.cat(
321             [direction, c_quizzes[:, l + 1 :], c_quizzes[:, 1 : l + 1]], dim=1
322         )
323
324     def compute_correctness(
325         self,
326         c_quizzes,
327         models_for_validation,
328         bidirectional_validation=False,
329         deterministic_validation=True,
330     ):
331         reversed_c_quizzes = self.reverse_time(c_quizzes)
332
333         ar_mask = self.make_ar_mask(c_quizzes)
334         seq_logproba = torch.zeros(
335             c_quizzes.size(0),
336             max([m.id for m in models_for_validation]) + 1,
337             device=self.device,
338         )
339
340         # Check how many of models can solve the quizzes in both directions
341
342         nb_correct = 0
343
344         for model in models_for_validation:
345             result = c_quizzes.clone()
346
347             seq_logproba[...] = 0.0
348
349             masked_inplace_autoregression(
350                 model=model,
351                 batch_size=self.batch_size,
352                 input=result,
353                 ar_mask=ar_mask,
354                 seq_logproba=seq_logproba[:, model.id],
355                 temperature=1.0,
356                 deterministic_synthesis=deterministic_validation,
357                 # progress_bar_desc="solving c_quizzes",
358                 device=self.device,
359             )
360
361             correct = (c_quizzes == result).long().min(dim=-1).values
362
363             if bidirectional_validation:
364                 reversed_result = reversed_c_quizzes.clone()
365
366                 masked_inplace_autoregression(
367                     model=model,
368                     batch_size=self.batch_size,
369                     input=reversed_result,
370                     ar_mask=ar_mask,
371                     seq_logproba=seq_logproba[:, model.id],
372                     temperature=1.0,
373                     deterministic_synthesis=deterministic_validation,
374                     # progress_bar_desc="solving reversed c_quizzes",
375                     device=self.device,
376                 )
377
378                 reversed_correct = (
379                     (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
380                 )
381
382                 correct *= reversed_correct
383
384             # endif
385
386             nb_correct += correct
387
388         return nb_correct, seq_logproba
389
390     ###############################################################
391
392     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
393         c_quizzes = torch.empty(
394             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
395         )
396
397         ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
398         ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
399         ar_mask_second = 1 - ar_mask_first
400         ar_mask_first[:, 0] = 0
401         ar_mask_second[:, 0] = 0
402
403         seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)
404
405         # First, we generate the answer at high temperature
406
407         c_quizzes[:, 0] = self.token_backward
408
409         masked_inplace_autoregression(
410             model=model_for_generation,
411             batch_size=self.batch_size,
412             input=c_quizzes,
413             ar_mask=ar_mask_first,
414             seq_logproba=seq_logproba,
415             temperature=temperature,
416             deterministic_synthesis=False,
417             device=self.device,
418         )
419
420         # Then, we generate the prompt at low temperature
421
422         masked_inplace_autoregression(
423             model=model_for_generation,
424             batch_size=self.batch_size,
425             input=c_quizzes,
426             ar_mask=ar_mask_second,
427             seq_logproba=seq_logproba,
428             temperature=1 / temperature,
429             deterministic_synthesis=False,
430             device=self.device,
431         )
432
433         # Then we return the quizz, and re-generate the response, now
434         # at low temperature
435
436         c_quizzes = self.reverse_time(c_quizzes)
437
438         masked_inplace_autoregression(
439             model=model_for_generation,
440             batch_size=self.batch_size,
441             input=c_quizzes,
442             ar_mask=ar_mask_second,
443             seq_logproba=seq_logproba,
444             temperature=1 / temperature,
445             deterministic_synthesis=False,
446             device=self.device,
447         )
448
449         return c_quizzes