7b0b877bdc0b04b0eec304850630461878493b96
[culture.git] / quizz_machine.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, os, tqdm, warnings
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 import mygpt
16 from mygpt import BracketedSequence
17
18 ######################################################################
19
20
class Gang(nn.Module):
    """Wraps several models and generates with a random subset of them.

    On every restart of the auto-regressive process (bs.first == 0), a
    random subset of nb_models_for_generation models is drawn. Their
    logits are then combined according to `mode`:
      - "groupthink": average the logits of the selected models;
      - "groupwork": for each logit entry independently, keep the value
        coming from one of the selected models chosen at random.
    """

    def __init__(self, models, nb_models_for_generation, mode="groupthink"):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.nb_models_for_generation = nb_models_for_generation
        self.mode = mode

    def forward(self, bs):
        # If first = 0, we are re-starting an auto-regressive process,
        # that's the right moment to randomize who gonna do it
        if bs.first == 0:
            self.models_to_use = [
                self.models[k]
                for k in torch.randperm(len(self.models))[
                    : self.nb_models_for_generation
                ]
            ]

        # Stack the selected models' logits along a new leading dim.
        all_the_logits = torch.cat(
            [model(bs).x[None] for model in self.models_to_use], dim=0
        )

        if self.mode == "groupthink":
            y = all_the_logits.mean(dim=0)
        elif self.mode == "groupwork":
            # Random one-hot mask along the model dimension: for each
            # entry, exactly one model's logit survives.
            m = torch.rand(all_the_logits.size(), device=all_the_logits.device)
            m = (m.sort(dim=0).indices == 0).long()
            # BUG FIX: the original read the not-yet-assigned `y` here
            # instead of `all_the_logits`, raising an UnboundLocalError
            # whenever mode was "groupwork".
            y = (all_the_logits * m).sum(dim=0)
        else:
            raise ValueError(f"Invalid mode {self.mode}")

        return BracketedSequence(y, bs.first, bs.nb)
53
54
55 ######################################################################
56
57 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
58 # 1s where tokens should be generated. The others are kept
59 # unchanged.
60
61
def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
    forbidden_tokens=None,
    forced_biases=None,
):
    """Autoregressively fills, in place, the positions of `input` where
    `ar_mask` is 1, one token column at a time over the whole batch.

    Args:
        model: callable taking and returning a BracketedSequence.
        input: (N, T) int64 token tensor, modified in place.
        ar_mask: (N, T) tensor of 0s/1s; 1s mark positions to generate.
        seq_logproba: (N,) tensor, incremented in place with the
            log-probability of each sampled/chosen token.
        temperature: softmax temperature applied to the logits.
        deterministic_synthesis: if True, take the argmax instead of
            sampling from the categorical distribution.
        forbidden_tokens: optional boolean mask; matching logits are set
            to -inf before choosing.
        forced_biases: optional (V,) additive bias on the logits.
    """
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    if to_generate.min() > 0:
        model(
            BracketedSequence(input, 0, to_generate.min())
        )  # Needed to initialize the model's cache
    for s in range(to_generate.min(), to_generate.max() + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = output[:, s]

        logits = (logits / temperature).log_softmax(dim=-1)

        if forbidden_tokens is not None:
            logits = logits.masked_fill(forbidden_tokens, float("-inf"))

        if forced_biases is not None:
            logits = logits + forced_biases[None, :]

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        all_n = torch.arange(t_next.size(0))

        # BUG FIX: logits[all_n, t_next] is already one scalar per
        # sequence (shape (N,)); the original `.sum(dim=-1)` collapsed
        # it over the batch and broadcast the same batch total into
        # every per-sequence slot of seq_logproba.
        seq_logproba += logits[all_n, t_next]

        # Overwrite only the masked positions; keep the rest untouched.
        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
101
102
def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    """Runs one_batch_masked_inplace_autoregression over `input` split
    into mini-batches of `batch_size`, without gradients and with the
    model temporarily in eval mode.

    `input` and `seq_logproba` are modified in place (the splits are
    views into them). If `progress_bar_desc` is given, the batches are
    wrapped in a tqdm progress bar.

    NOTE(review): `device` is accepted but unused here — tensors are
    expected to already live on the right device; confirm with callers.
    """
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    # torch.no_grad() is the canonical spelling of torch.autograd.no_grad().
    with torch.no_grad():
        t = model.training
        model.eval()

        # FIX: restore the model's training mode even if generation
        # raises, so a failure cannot leave the model stuck in eval.
        try:
            for input, ar_mask, seq_logproba in batches:
                one_batch_masked_inplace_autoregression(
                    model=model,
                    input=input,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba,
                    temperature=temperature,
                    deterministic_synthesis=deterministic_synthesis,
                    forbidden_tokens=forbidden_tokens,
                    forced_biases=logit_biases,
                )
        finally:
            model.train(t)
149
150
151 ######################################################################
152
153
class QuizzMachine:
    """Generates, stores, and serves quizzes for training and evaluation.

    Two pools of quizzes are kept per split:
      - "w_quizzes", produced by self.problem.generate_token_sequences;
      - "c_quizzes", produced by models and added via store_c_quizzes.
    NOTE(review): the w_/c_ naming presumably stands for "world"/"culture"
    (cf. the "culture_w_quizzes" save name) — confirm with the project.

    A quiz is a fixed-length token sequence; make_ar_mask marks the part
    strictly past the middle position as the part a model must generate.
    Presumably the layout is [prompt | direction token | answer] — see
    reverse_time — TODO confirm against the problem class.
    """

    def make_ar_mask(self, input):
        """Return a 0/1 mask of input's shape: 1s at positions strictly
        past the middle of the sequence (to be generated), 0s elsewhere
        (kept as the prompt)."""
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir,
        logger,
        device=torch.device("cpu"),
    ):
        """Generate the initial train/test w_quizzes with `problem` and,
        if result_dir is given, save a sample of 72 of them.

        Args:
            problem: object providing generate_token_sequences,
                save_quizzes, direction_tokens, and the
                token_forward/token_backward attributes.
            nb_train_samples: number of training w_quizzes to generate.
            nb_test_samples: number of test w_quizzes to generate.
            batch_size: mini-batch size for all generation/solving.
            result_dir: directory for saved quizzes, or None to skip.
            logger: callable taking a log-message string.
            device: torch device on which the quiz tensors are kept.
        """
        super().__init__()

        self.problem = problem
        self.batch_size = batch_size
        self.device = device
        self.logger = logger

        self.train_w_quizzes = self.problem.generate_token_sequences(
            nb_train_samples
        ).to(device)
        self.test_w_quizzes = self.problem.generate_token_sequences(nb_test_samples).to(
            device
        )

        # Vocabulary size: one past the largest token id seen in either split.
        self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1

        # Pools of model-generated quizzes, appended to by store_c_quizzes.
        self.train_c_quizzes = []
        self.test_c_quizzes = []

        if result_dir is not None:
            self.problem.save_quizzes(
                self.train_w_quizzes[:72], result_dir, "culture_w_quizzes"
            )

    def batches(self, split="train", desc=None):
        """Yield shuffled mini-batches mixing w_quizzes and c_quizzes.

        At most half of the yielded samples come from the c_quizzes
        pool; w_quizzes are subsampled so the total count equals the
        w_quizzes count. Records self.nb_batch_w_quizzes and
        self.nb_batch_c_quizzes as a side effect.
        """
        assert split in {"train", "test"}
        if split == "train":
            w_quizzes = self.train_w_quizzes
            c_quizzes = self.train_c_quizzes
        else:
            w_quizzes = self.test_w_quizzes
            c_quizzes = self.test_c_quizzes

        if len(c_quizzes) > 0:
            c_quizzes = torch.cat(c_quizzes, dim=0)
            # Cap c_quizzes at half the number of w_quizzes.
            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                c_quizzes = c_quizzes[i]

            # Subsample w_quizzes so the total matches the w_quizzes count.
            i = torch.randperm(w_quizzes.size(0))[
                : w_quizzes.size(0) - c_quizzes.size(0)
            ]
            w_quizzes = w_quizzes[i]

            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = c_quizzes.size(0)

            input = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            input = w_quizzes
            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Number of distinct token values (largest token id + 1)."""
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
    ):
        """Evaluate `model` on train/test w_quizzes, log the accuracies,
        save a sample of test predictions, and return the test accuracy.

        A quiz counts as correct only if every position of the
        regenerated sequence matches the reference.
        """
        def compute_accuracy(input):
            # Evaluate on at most nmax quizzes.
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            # Zero out the part the model must generate.
            result = input.clone() * (1 - ar_mask)
            # NOTE(review): allocated uninitialized and incremented
            # downstream; harmless here since it is never read.
            seq_logproba = torch.empty(input.size(0), device=self.device)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            # min over positions == 1 iff every token matches.
            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)

        self.logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)

        self.logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################
        # Save a sample of the model's test predictions.

        input = self.test_w_quizzes[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)
        seq_logproba = torch.empty(input.size(0), device=self.device)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            seq_logproba=seq_logproba,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.problem.save_quizzes(
            result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
        )

        return main_test_accuracy

    def renew_w_quizzes(self, nb, for_train=True):
        """Replace, in place, the `nb` oldest w_quizzes of the chosen
        split with freshly generated ones (the remainder is shifted
        toward the front)."""
        input = self.train_w_quizzes if for_train else self.test_w_quizzes
        nb = min(nb, input.size(0))
        input[:-nb] = input[nb:].clone()
        input[-nb:] = self.problem.generate_token_sequences(nb).to(self.device)

    def store_c_quizzes(self, new_c_quizzes, for_train=True):
        """Append a tensor of model-generated quizzes to the chosen
        split's c_quizzes pool."""
        if for_train:
            self.train_c_quizzes.append(new_c_quizzes)
        else:
            self.test_c_quizzes.append(new_c_quizzes)

    def reverse_time(self, c_quizzes):
        """Return the quizzes with the two halves swapped around the
        middle "direction" token, which is itself flipped between
        problem.token_forward and problem.token_backward.

        NOTE(review): the locals returned by direction_tokens() are
        unused; the code reads problem.token_forward/token_backward
        directly. A middle token that is neither becomes 0.
        """
        token_forward, token_backward = self.problem.direction_tokens()

        # Index of the middle (direction) token.
        l = (c_quizzes.size(1) - 1) // 2
        direction = c_quizzes[:, l : l + 1]
        # Flip the direction token: forward <-> backward.
        direction = self.problem.token_forward * (
            direction == self.problem.token_backward
        ) + self.problem.token_backward * (direction == self.problem.token_forward)

        return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)

    def comput_correctness(self, c_quizzes, models_for_validation):
        """Count, per quiz, how many validation models solve it in BOTH
        time directions (greedy decoding must reproduce the quiz
        exactly). Returns a tensor of per-quiz counts.

        NOTE(review): the method name keeps the original "comput_" typo,
        since external callers may rely on it.
        """
        reversed_c_quizzes = self.reverse_time(c_quizzes)

        ar_mask = self.make_ar_mask(c_quizzes)
        # NOTE(review): uninitialized accumulator; never read here.
        seq_logproba = torch.empty(ar_mask.size(0), device=self.device)

        # Check how many of models can solve the quizzes in both directions

        nb_correct = 0

        for model in models_for_validation:
            result = c_quizzes.clone()

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=True,
                # progress_bar_desc="solving c_quizzes",
                device=self.device,
            )

            # 1 iff the model reproduced the quiz exactly.
            correct = (c_quizzes == result).long().min(dim=-1).values

            reversed_result = reversed_c_quizzes.clone()

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=reversed_result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=True,
                # progress_bar_desc="solving reversed c_quizzes",
                device=self.device,
            )

            reversed_correct = (
                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
            )

            # Credit only models that solve both directions.
            nb_correct += correct * reversed_correct

        return nb_correct

    ###############################################################

    def generate_quizzes(
        self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
    ):
        """Sample `nb` new c_quizzes with `model_for_generation`.

        The first half plus the middle direction token is sampled
        stochastically, then the remaining half is completed greedily.
        If reverse_cleanup is True, the quizzes are time-reversed and
        the (new) second half is re-generated greedily.

        Returns (c_quizzes, mean of the accumulated seq_logproba).

        NOTE(review): min_ave_seq_logproba is currently unused, and
        seq_logproba is allocated with torch.empty and only ever
        incremented downstream — it looks like it should be zeroed
        before each pass; verify before relying on the returned mean.
        """
        c_quizzes = torch.empty(
            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
        )

        # Prompt mask: first half plus the middle direction token.
        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
        ar_mask_solve = 1 - ar_mask_prompt
        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)

        # warnings.warn("noise injection", RuntimeWarning)
        temperature = 1
        # noise_std = torch.rand(1).item()
        # self.logger(f"{noise_std=}")

        # mygpt.set_noise_injection(model_for_generation, noise_std)

        # Sample the prompt part stochastically.
        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_prompt,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=False,
            # progress_bar_desc="sampling c_quizzes",
            device=self.device,
        )

        # mygpt.set_noise_injection(model_for_generation, 0.0)

        ave_seq_logproba = seq_logproba.mean()

        # Complete the remaining half greedily.
        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_solve,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=True,
            # progress_bar_desc="sampling c_quizzes",
            device=self.device,
        )

        if reverse_cleanup:
            # Re-solve the time-reversed quizzes greedily.
            c_quizzes = self.reverse_time(c_quizzes)
            masked_inplace_autoregression(
                model=model_for_generation,
                batch_size=self.batch_size,
                input=c_quizzes,
                ar_mask=ar_mask_solve,
                seq_logproba=seq_logproba,
                temperature=temperature,
                deterministic_synthesis=True,
                # progress_bar_desc="sampling c_quizzes",
                device=self.device,
            )

        return c_quizzes, seq_logproba.mean()

    ######################################################################

    def create_c_quizzes(
        self,
        nb,
        model_for_generation,
        models_for_validation,
        min_ave_seq_logproba,
        n_epoch,
        result_dir,
    ):
        """Generate `nb` c_quizzes and count, per quiz, how many
        validation models solve them in both directions.

        Returns (c_quizzes, nb_correct, ave_seq_logproba).

        NOTE(review): n_epoch and result_dir are currently unused here.
        """
        c_quizzes, ave_seq_logproba = self.generate_quizzes(
            nb, model_for_generation, min_ave_seq_logproba
        )

        nb_correct = self.comput_correctness(c_quizzes, models_for_validation)

        return c_quizzes, nb_correct, ave_seq_logproba

    ######################################################################

    def gang_create_c_quizzes(
        self,
        nb,
        nb_models_for_generation,
        models,
        mode,
        min_ave_seq_logproba,
        n_epoch,
        result_dir,
    ):
        """Same as create_c_quizzes, but generation is performed by a
        Gang of nb_models_for_generation models drawn from `models`
        (combined according to `mode`), and all of `models` validate."""
        model_for_generation = Gang(models, nb_models_for_generation, mode)
        models_for_validation = models
        return self.create_c_quizzes(
            nb,
            model_for_generation,
            models_for_validation,
            min_ave_seq_logproba,
            n_epoch,
            result_dir,
        )