#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, os, tqdm, warnings

import torch, torchvision

from torch import nn
from torch.nn import functional as F

from mygpt import BracketedSequence

######################################################################


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    summed_logits,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
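    # Regenerates, in place, the entries of `input` where `ar_mask` is 1,
    # delegating to the model's own masked_inplace_autoregression, batch by
    # batch, with gradients disabled and the model in eval mode; the model's
    # training mode is restored afterwards.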
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input=input,
                ar_mask=ar_mask,
                summed_logits=summed_logits,
                temperature=temperature,
                deterministic_synthesis=deterministic_synthesis,
                forbidden_tokens=forbidden_tokens,
                forced_biases=logit_biases,
            )

        model.train(t)


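# A minimal usage sketch (hypothetical shapes; `model` is assumed to be a
# mygpt.MyGPT-style model exposing `masked_inplace_autoregression`):
#
#   input = torch.zeros(16, 97, dtype=torch.int64)
#   ar_mask = (torch.arange(97) > 97 // 2).long()[None, :].expand_as(input)
#   masked_inplace_autoregression(
#       model, batch_size=4, input=input, ar_mask=ar_mask,
#       summed_logits=None, temperature=1.0, deterministic_synthesis=True,
#   )
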
######################################################################


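# Abstract task interface: concrete tasks provide batches of token
# sequences, the vocabulary size, and an evaluation procedure.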
class Task:
    def batches(self, split="train", nb_to_use=-1, desc=None):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import world


class World(Task):
    def save_image(self, input, result_dir, filename, logger):
        img = world.seq2img(input.to("cpu"), self.height, self.width)
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
        logger(f"wrote {image_name}")

    def make_ar_mask(self, input):
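        # 1 on the positions after the mid-point of the sequence, 0 before:
        # the autoregression regenerates exactly the masked second half.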
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir=None,
        logger=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        self.height = 6
        self.width = 8

        self.train_input = world.generate_seq(
            nb_train_samples, height=self.height, width=self.width
        ).to(device)

        self.test_input = world.generate_seq(
            nb_test_samples, height=self.height, width=self.width
        ).to(device)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

        self.train_quizzes = []
        self.test_quizzes = []

        if result_dir is not None:
            self.save_image(
                self.train_input[:72], result_dir, "world_train.png", logger
            )

    def batches(self, split="train", desc=None):
        assert split in {"train", "test"}
        if split == "train":
            input = self.train_input
            quizzes = self.train_quizzes
        else:
            input = self.test_input
            quizzes = self.test_quizzes

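        # Mix the stored quizzes into the pool of world samples: quizzes
        # are capped at half of the pool, world samples fill the rest.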
        if len(quizzes) > 0:
            quizzes = torch.cat(quizzes, dim=0)
            if quizzes.size(0) > input.size(0) // 2:
                i = torch.randperm(quizzes.size(0))[: input.size(0) // 2]
                quizzes = quizzes[i]

            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
            input = input[i]

            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = quizzes.size(0)

            input = torch.cat([input, quizzes], dim=0)
        else:
            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input, logger=None):
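            # Erase the masked half, let the model regenerate it, and count
            # a sequence as correct only if every token matches (hence the
            # min over the token dimension below).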
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)

        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

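        # Regenerate the masked half of a few test sequences and save an
        # image of the predictions.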
        input = self.test_input[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            summed_logits=None,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_image(
            result[:72],
            result_dir,
            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
            logger,
        )

        return main_test_accuracy

    def renew_samples(self, nb, for_train=True):
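        # Age the data in place: drop the first `nb` samples, shift the
        # rest forward, and append `nb` freshly generated ones at the end.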
        input = self.train_input if for_train else self.test_input
        nb = min(nb, input.size(0))
        input[:-nb] = input[nb:].clone()
        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
            self.device
        )

    def store_new_quizzes(self, new_quizzes, for_train=True):
        if for_train:
            self.train_quizzes.append(new_quizzes)
        else:
            self.test_quizzes.append(new_quizzes)

    def create_new_quizzes(
        self,
        n_epoch,
        result_dir,
        logger,
        nb,
        model,
        other_models,
        desired_average_logits=None,
    ):
        ###############################################################
        # Generate quizzes with model

        quizzes = torch.empty(
            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
        )

        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
        summed_logits = torch.empty(nb, device=self.device)

        temperature = 1
        d_temperature = 1

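        # Sample full quizzes, adjusting the temperature until the mean of
        # the summed logits lands in the target band: the step halves and
        # reverses direction whenever the target is overshot.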
        while True:
            summed_logits[...] = 0

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=quizzes,
                ar_mask=ar_mask,
                summed_logits=summed_logits,
                temperature=temperature,
                deterministic_synthesis=False,
                progress_bar_desc="creating quizzes",
                device=self.device,
            )

            average_logits = summed_logits.mean()

            logger(f"{average_logits=} {desired_average_logits=}")

            if desired_average_logits is None:
                break

            # Oh man that's ugly
            if average_logits < desired_average_logits * 1.1:
                if d_temperature > 0:
                    d_temperature *= -0.5
                temperature += d_temperature
            elif average_logits > desired_average_logits:
                if d_temperature < 0:
                    d_temperature *= -0.5
                temperature += d_temperature
            else:
                break

            logger(f"changing temperature to {temperature}")

        ###############################################################
        # Create the reverse quizzes

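        # A quiz is [first grid | direction token | second grid]. The
        # reverse quiz swaps the two grids and flips the direction token,
        # posing the same problem in the opposite direction.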
        n_grid = self.height * self.width
        direction = quizzes[:, n_grid : n_grid + 1]
        direction = world.token_forward * (
            direction == world.token_backward
        ) + world.token_backward * (direction == world.token_forward)
        reverse_quizzes = torch.cat(
            [quizzes[:, n_grid + 1 :], direction, quizzes[:, :n_grid]], dim=1
        )

        ar_mask = self.make_ar_mask(quizzes)

        ###############################################################
        # Check how many of the other models can solve them in both
        # directions

        nb_correct = []

        for m in other_models:
            result = quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving quizzes",
                device=self.device,
            )

            correct = (quizzes == result).long().min(dim=-1).values

            reverse_result = reverse_quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=reverse_result,
                ar_mask=ar_mask,
                summed_logits=None,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving reversed quizzes",
                device=self.device,
            )

            reverse_correct = (
                (reverse_quizzes == reverse_result).long().min(dim=-1).values
            )

            nb_correct.append((correct * reverse_correct)[None, :])

        nb_correct = torch.cat(nb_correct, dim=0)

        # filename = os.path.join(result_dir, f"correct_{n_epoch:04d}.dat")
        # with open(filename, "w") as f:
        #     for k in nb_correct:
        #         f.write(f"{k}\n")
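
        # For each quiz, return how many of the other models solved it in
        # both directions, together with the mean summed logits of the
        # final generation pass.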
        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
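

######################################################################

# A minimal usage sketch (hypothetical sample counts; the models are
# assumed to be mygpt.MyGPT instances with an integer `id` attribute):
#
#   task = World(nb_train_samples=25000, nb_test_samples=1000, batch_size=25)
#   for input in task.batches(split="train"):
#       ...  # one optimization step of next-token prediction on `input`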