153317c1536cd0a90662ebf534b20f825befba9e
[culture.git] / quizz_machine.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 import mygpt
16 from mygpt import BracketedSequence
17
18 ######################################################################
19
20 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
21 # 1s where tokens should be generated. The others are kept
22 # unchanged.
23
24
def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
):
    """Autoregressively fill, in place, the positions of `input` where
    `ar_mask` is 1, for a single batch.

    Args:
        model: callable taking a BracketedSequence and returning one whose
            `.x` holds the per-position logits.
        input: (B, T) long tensor, modified in place.
        ar_mask: (B, T) 0/1 tensor; 1 marks positions to generate, 0s are
            kept unchanged.
        seq_logproba: (B,) tensor accumulating, in place, the log-proba of
            the sampled tokens for each sequence.
        temperature: softmax temperature used for sampling.
        deterministic_synthesis: if True, take the argmax instead of
            sampling.
    """
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    if to_generate.min() > 0:
        model(
            BracketedSequence(input, 0, to_generate.min())
        )  # Needed to initialize the model's cache

    for s in range(to_generate.min(), to_generate.max() + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = (output[:, s] / temperature).log_softmax(dim=-1)

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        all_n = torch.arange(t_next.size(0))

        # BUG FIX: logits[all_n, t_next] is already one log-proba per
        # sequence (shape (B,)); the former .sum(dim=-1) collapsed the
        # batch dimension and added the batch total to every entry of
        # seq_logproba instead of the per-sequence value.
        seq_logproba += logits[all_n, t_next]

        # Only overwrite the positions selected by the mask.
        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
56
57
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    """Autoregressively fill, in place, the masked positions of `input`,
    processing it in chunks of `batch_size` with the model in eval mode
    and gradients disabled.

    Args:
        model: the generative model; its training flag is saved and
            restored around generation.
        batch_size: chunk size for splitting input/ar_mask/seq_logproba.
        input: (N, T) long tensor, modified in place.
        ar_mask: (N, T) 0/1 tensor, same shape as input; 1 marks the
            positions to generate.
        seq_logproba: (N,) tensor accumulating per-sequence log-probas.
        temperature / deterministic_synthesis: forwarded to the per-batch
            generation routine.
        forbidden_tokens, logit_biases: accepted for interface
            compatibility; currently unused here.
        progress_bar_desc: if not None, wrap the batches in a tqdm bar
            with this description.
        device: accepted for interface compatibility; tensors are used on
            whatever device they already live on.
    """
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.no_grad():
        t = model.training
        model.eval()

        # FIX: restore the model's training mode even if generation
        # raises, so an exception cannot leave the model stuck in eval.
        try:
            for input, ar_mask, seq_logproba in batches:
                one_batch_masked_inplace_autoregression(
                    model=model,
                    input=input,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba,
                    temperature=temperature,
                    deterministic_synthesis=deterministic_synthesis,
                )
        finally:
            model.train(t)
102
103
104 ######################################################################
105
106
class QuizzMachine:
    """Generates, stores and evaluates quizzes built on top of a problem
    object.

    A quiz is a token sequence [d, A, d, B] where d is a direction token:
    token_forward means A is the prompt and B the answer, token_backward
    means A is the answer and B the prompt. "w_quizzes" come from the
    problem generator, "c_quizzes" are model-generated "culture" quizzes.
    """

    def make_ar_mask(self, input, first, nb):
        """Return a 0/1 long mask shaped like `input`, with 1s on the
        token positions [first, first + nb) that should be generated."""
        i = torch.arange(input.size(1), device=input.device)
        b = torch.logical_and(i >= first, i < first + nb)
        return b.long()[None, :].expand_as(input)

    def generate_token_sequences(self, nb):
        """Produce `nb` quizzes from the problem, each one randomly in
        forward ([fw, prompt, fw, answer]) or backward
        ([bw, answer, bw, prompt]) order.

        Also records prompt_len / answer_len on first use and asserts the
        problem keeps producing those lengths afterward.
        """
        prompts, answers = self.problem.generate_prompts_and_answers(nb)

        if self.prompt_len is None:
            self.prompt_len = prompts.size(1)

        if self.answer_len is None:
            self.answer_len = answers.size(1)

        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len

        result = []

        for prompt, answer in zip(prompts, answers):
            # Flip a fair coin to pick the quiz direction.
            if torch.rand(1) < 0.5:
                a = [
                    torch.tensor([self.token_forward]),
                    prompt,
                    torch.tensor([self.token_forward]),
                    answer,
                ]
            else:
                a = [
                    torch.tensor([self.token_backward]),
                    answer,
                    torch.tensor([self.token_backward]),
                    prompt,
                ]

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir,
        logger,
        device=torch.device("cpu"),
    ):
        """Build the train/test w_quizz sets from `problem` and optionally
        save a sample of them under `result_dir`.

        Args:
            problem: object providing nb_token_values(),
                generate_prompts_and_answers() and save_quizzes().
            nb_train_samples / nb_test_samples: number of w_quizzes to
                pre-generate for each split.
            batch_size: batch size used for generation and iteration.
            result_dir: directory for saved quiz images, or None to skip
                saving.
            logger: callable taking a log string.
            device: device the quiz tensors are stored on.
        """
        super().__init__()

        # Two extra tokens are appended after the problem's vocabulary to
        # encode the quiz direction.
        v = problem.nb_token_values()
        self.token_forward = v
        self.token_backward = v + 1
        self.nb_token_values = v + 2

        self.problem = problem
        self.batch_size = batch_size
        self.device = device
        self.logger = logger
        self.prompt_len = None
        self.answer_len = None

        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples).to(
            device
        )

        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)

        self.train_c_quizzes = []
        self.test_c_quizzes = []

        if result_dir is not None:
            self.save_quizzes(
                result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
            )

    def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
        """Split quizzes back into (prompt, answer) pairs according to
        their direction token and delegate rendering to the problem.

        If `prediction` is True, mark which half was model-predicted (the
        second half of the sequence, whichever of prompt/answer it is).
        """
        half = (quizzes.size(1) - 1) // 2
        forward = (quizzes[:, 0] == self.token_forward).long()
        backward = (quizzes[:, 0] == self.token_backward).long()
        # Every quiz must start with exactly one of the two direction tokens.
        assert forward.equal(1 - backward)
        first = quizzes[:, 1 : 1 + half]
        second = quizzes[:, 1 + half : 1 + 2 * half]
        prompts = forward[:, None] * first + backward[:, None] * second
        answers = forward[:, None] * second + backward[:, None] * first

        if prediction:
            predicted_prompts = backward
            predicted_answers = forward
        else:
            predicted_prompts = None
            predicted_answers = None

        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts,
            predicted_answers,
        )

    def batches(self, split="train", desc=None):
        """Yield shuffled training/testing batches mixing w_quizzes and
        stored c_quizzes; c_quizzes are capped at half the batch pool."""
        assert split in {"train", "test"}
        if split == "train":
            w_quizzes = self.train_w_quizzes
            c_quizzes = self.train_c_quizzes
        else:
            w_quizzes = self.test_w_quizzes
            c_quizzes = self.test_c_quizzes

        if len(c_quizzes) > 0:
            c_quizzes = torch.cat(c_quizzes, dim=0)
            # Keep at most half the pool as c_quizzes, sub-sampled at random.
            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                c_quizzes = c_quizzes[i]

            # Fill the rest of the pool with a random subset of w_quizzes.
            i = torch.randperm(w_quizzes.size(0))[
                : w_quizzes.size(0) - c_quizzes.size(0)
            ]
            w_quizzes = w_quizzes[i]

            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = c_quizzes.size(0)

            input = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            input = w_quizzes
            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Total number of token values, direction tokens included."""
        return self.nb_token_values

    def produce_results(
        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
    ):
        """Measure the model's answer-prediction accuracy on train/test
        w_quizzes, log it, save a sample of predictions, and return the
        test accuracy."""

        def compute_accuracy(input):
            # Blank out the answer part and let the model regenerate it;
            # a quiz is correct only if every token matches.
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
            result = input.clone() * (1 - ar_mask)
            seq_logproba = torch.empty(input.size(0), device=self.device)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            nb_total = input.size(0)
            nb_correct = (input == result).long().min(dim=1).values.sum()

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)

        self.logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)

        self.logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

        # Save a visual sample of predictions on the first test quizzes.
        input = self.test_w_quizzes[:96]
        ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
        result = input.clone() * (1 - ar_mask)
        seq_logproba = torch.empty(input.size(0), device=self.device)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            seq_logproba=seq_logproba,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_quizzes(
            result_dir,
            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
            quizzes=result[:72],
            prediction=True,
        )

        return main_test_accuracy

    def renew_w_quizzes(self, nb, for_train=True):
        """Replace, in place, the `nb` oldest w_quizzes of the chosen
        split with freshly generated ones."""
        input = self.train_w_quizzes if for_train else self.test_w_quizzes
        nb = min(nb, input.size(0))
        # BUG FIX: with nb == 0, input[:-0] is the empty slice and the
        # assignments below would fail with a shape mismatch.
        if nb == 0:
            return
        input[:-nb] = input[nb:].clone()
        input[-nb:] = self.generate_token_sequences(nb).to(self.device)

    def store_c_quizzes(self, new_c_quizzes, for_train=True):
        """Append a batch of c_quizzes to the train or test pool."""
        if for_train:
            self.train_c_quizzes.append(new_c_quizzes)
        else:
            self.test_c_quizzes.append(new_c_quizzes)

    def forward_to_backward(self, c_quizzes):
        """Rewrite forward quizzes [fw, prompt, fw, answer] as backward
        quizzes [bw, answer, bw, prompt]."""
        prompts = c_quizzes[:, 1 : 1 + self.prompt_len]
        answers = c_quizzes[:, 2 + self.prompt_len :]
        return torch.cat(
            [
                # BUG FIX: new_full expects a size tuple; the former code
                # passed the tensor itself ((c_quizzes, 1)) instead of
                # (c_quizzes.size(0), 1), matching backward_to_forward.
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
                answers,
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
                prompts,
            ],
            dim=1,
        )

    def backward_to_forward(self, c_quizzes):
        """Rewrite backward quizzes [bw, answer, bw, prompt] as forward
        quizzes [fw, prompt, fw, answer]."""
        answers = c_quizzes[:, 1 : 1 + self.answer_len]
        prompts = c_quizzes[:, 2 + self.answer_len :]
        return torch.cat(
            [
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
                prompts,
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
                answers,
            ],
            dim=1,
        )

    def compute_correctness(
        self,
        c_quizzes,
        models_for_validation,
        bidirectional_validation=False,
        deterministic_validation=True,
    ):
        """Count, for each c_quiz, how many validation models solve it
        exactly (optionally in both directions).

        Returns:
            nb_correct: (N,) count of models that reproduced the quiz.
            seq_logproba: (N, max_model_id + 1) per-model log-probas of
                the last generation pass.
        """
        if bidirectional_validation:
            backward_c_quizzes = self.forward_to_backward(c_quizzes)

        seq_logproba = torch.zeros(
            c_quizzes.size(0),
            max([m.id for m in models_for_validation]) + 1,
            device=self.device,
        )

        nb_correct = 0

        for model in models_for_validation:
            result = c_quizzes.clone()

            seq_logproba[...] = 0.0

            ar_mask = self.make_ar_mask(result, 2 + self.prompt_len, self.answer_len)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba[:, model.id],
                temperature=1.0,
                deterministic_synthesis=deterministic_validation,
                # progress_bar_desc="solving c_quizzes",
                device=self.device,
            )

            # A quiz counts as solved only if every token matches.
            correct = (c_quizzes == result).long().min(dim=-1).values

            if bidirectional_validation:
                backward_result = backward_c_quizzes.clone()

                ar_mask = self.make_ar_mask(
                    backward_result, 2 + self.answer_len, self.prompt_len
                )

                masked_inplace_autoregression(
                    model=model,
                    batch_size=self.batch_size,
                    input=backward_result,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba[:, model.id],
                    temperature=1.0,
                    deterministic_synthesis=deterministic_validation,
                    # progress_bar_desc="solving backward c_quizzes",
                    device=self.device,
                )

                backward_correct = (
                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
                )

                # Correct only if solved in both directions.
                correct *= backward_correct

            nb_correct += correct

        return nb_correct, seq_logproba

    ###############################################################

    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
        """Generate `nb` new c_quizzes: sample an answer at high
        temperature, its prompt at low temperature, then re-generate the
        answer at low temperature in the forward direction."""
        c_quizzes = torch.empty(
            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
        )

        # Long 0/1 masks, consistent with make_ar_mask (the former float
        # masks relied on implicit casting during in-place generation).
        ar_mask_first = torch.zeros(
            c_quizzes.size(), device=self.device, dtype=torch.int64
        )
        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
        ar_mask_second = 1 - ar_mask_first
        # Never regenerate the direction tokens at position 0.
        ar_mask_first[:, 0] = 0
        ar_mask_second[:, 0] = 0

        seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)

        # First, we generate the answer at high temperature

        c_quizzes[:, 0] = self.token_backward

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_first,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then, we generate the prompt at low temperature

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then we flip the quiz to the forward direction and re-generate
        # the response, now at low temperature

        c_quizzes = self.backward_to_forward(c_quizzes)

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        return c_quizzes