#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, os, tqdm, warnings

import torch, torchvision

from torch import nn
from torch.nn import functional as F

from mygpt import BracketedSequence

######################################################################

def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
    # Regenerate in place the positions of input flagged by ar_mask,
    # processing the samples by chunks of batch_size
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.no_grad():
        # Generate in eval mode, and restore the previous training mode
        # afterwards
        t = model.training
        model.eval()

        sum_logits = 0

        for input, ar_mask in batches:
            sum_logits += model.masked_inplace_autoregression(
                input=input,
                ar_mask=ar_mask,
                temperature=temperature,
                deterministic_synthesis=deterministic_synthesis,
                forbidden_tokens=forbidden_tokens,
                forced_biases=logit_biases,
            )

        model.train(t)

        return sum_logits

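# A minimal usage sketch (the batch size and the shapes are assumptions):
# input is an (N, T) integer tensor and ar_mask a same-shape 0/1 tensor
# flagging the positions to regenerate, e.g.
#
#   ar_mask = task.make_ar_mask(input)
#   result = input.clone() * (1 - ar_mask)  # blank the masked positions
#   masked_inplace_autoregression(
#       model=model,
#       batch_size=25,
#       input=result,
#       ar_mask=ar_mask,
#       temperature=1.0,
#       deterministic_synthesis=True,
#   )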

######################################################################


class Task:
    # Abstract interface that concrete tasks such as World implement
    def batches(self, split="train", nb_to_use=-1, desc=None):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import world


class World(Task):
    def save_image(self, input, result_dir, filename, logger):
        img = world.seq2img(input.to("cpu"), self.height, self.width)
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
        logger(f"wrote {image_name}")

    def make_ar_mask(self, input):
        # Flag with 1 the second half of each sequence, which is the
        # part to be generated
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)
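
    # For instance, with the 97-token sequences used below
    # (height * width * 2 + 1 = 6 * 8 * 2 + 1), positions 0..48, that is
    # the first grid and the direction token, keep their values, and
    # positions 49..96, the second grid, are regenerated.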

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir=None,
        logger=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        self.height = 6
        self.width = 8

        self.train_input = world.generate_seq(
            nb_train_samples, height=self.height, width=self.width
        ).to(device)

        self.test_input = world.generate_seq(
            nb_test_samples, height=self.height, width=self.width
        ).to(device)

        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1

        self.train_quizzes = []
        self.test_quizzes = []

        if result_dir is not None:
            self.save_image(
                self.train_input[:72], result_dir, "world_train.png", logger
            )

    def batches(self, split="train", desc=None):
        assert split in {"train", "test"}
        if split == "train":
            input = self.train_input
            quizzes = self.train_quizzes
        else:
            input = self.test_input
            quizzes = self.test_quizzes

        if len(quizzes) > 0:
            quizzes = torch.cat(quizzes, dim=0)
            # Cap the quizzes at half of the samples, and complete with
            # world samples picked at random
            if quizzes.size(0) > input.size(0) // 2:
                i = torch.randperm(quizzes.size(0))[: input.size(0) // 2]
                quizzes = quizzes[i]

            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
            input = input[i]

            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = quizzes.size(0)

            input = torch.cat([input, quizzes], dim=0)
        else:
            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = 0

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input, logger=None):
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            # A sample counts as correct only if all its tokens match,
            # hence the min over the token dimension
            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)

        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################
        # Save an image of the predictions on a few test samples

        input = self.test_input[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_image(
            result[:72],
            result_dir,
            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
            logger,
        )

        return main_test_accuracy

    def renew_samples(self, nb, for_train=True):
        # Shift out the nb oldest samples and append nb freshly
        # generated ones at the end
        input = self.train_input if for_train else self.test_input
        nb = min(nb, input.size(0))
        input[:-nb] = input[nb:].clone()
        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
            self.device
        )
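
    # For instance, with nb=2 and rows [a, b, c, d], the rows become
    # [c, d, new, new]: the set behaves as a FIFO over the samples.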

    def store_new_quizzes(self, new_quizzes, for_train=True):
        if for_train:
            self.train_quizzes.append(new_quizzes)
        else:
            self.test_quizzes.append(new_quizzes)

    def create_new_quizzes(
        self,
        n_epoch,
        result_dir,
        logger,
        nb,
        model,
        other_models,
        desired_average_logits=None,
    ):
        ###############################################################
        # Generate quizzes with the model. A quiz is a sequence
        # [grid A, direction token, grid B] of length
        # height * width * 2 + 1, generated from scratch since every
        # position is in the ar_mask

        quizzes = torch.empty(
            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
        )
        ar_mask = torch.full(quizzes.size(), 1, device=self.device)

        sum_logits = masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=quizzes,
            ar_mask=ar_mask,
            temperature=1.0,
            deterministic_synthesis=False,
            progress_bar_desc="creating quizzes",
            device=self.device,
        )

        average_logits = sum_logits / quizzes.numel()

        # It's a bit brutal to generate twice; we should probably keep a
        # moving average of the logits and apply the correction right away

        if desired_average_logits is not None:
            # Rescale the temperature so that the average logit matches
            # the desired one, and generate again
            temperature = average_logits / desired_average_logits
            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=quizzes,
                ar_mask=ar_mask,
                temperature=temperature,
                deterministic_synthesis=False,
                progress_bar_desc="creating quizzes",
                device=self.device,
            )
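
        # A possible refinement, as the comment above suggests (an
        # assumption, not implemented here): keep an exponential moving
        # average across calls,
        #
        #   self.ema_logits = 0.9 * self.ema_logits + 0.1 * average_logits
        #
        # and use it to set the temperature of the first generation
        # pass, which would make the second pass unnecessary.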

        ###############################################################
        # Create the reverse quizzes by swapping the two grids and
        # flipping the direction token

        l = self.height * self.width
        direction = quizzes[:, l : l + 1]
        direction = world.token_forward * (
            direction == world.token_backward
        ) + world.token_backward * (direction == world.token_forward)
        reverse_quizzes = torch.cat(
            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
        )
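
        # Layout sketch: a quiz is [A, d, B], with A and B two grids of
        # height * width tokens each and d the direction token; the
        # reverse quiz is [B, d', A] with d' the opposite direction, so
        # solving it means recovering A from B.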

        ar_mask = self.make_ar_mask(quizzes)

        ###############################################################
        # Check how many of the other models can solve them in both
        # directions

        nb_correct = []

        for m in other_models:
            result = quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving quizzes",
                device=self.device,
            )

            correct = (quizzes == result).long().min(dim=-1).values

            reverse_result = reverse_quizzes.clone()

            masked_inplace_autoregression(
                model=m,
                batch_size=self.batch_size,
                input=reverse_result,
                ar_mask=ar_mask,
                temperature=1.0,
                deterministic_synthesis=True,
                progress_bar_desc="solving reversed quizzes",
                device=self.device,
            )

            reverse_correct = (
                (reverse_quizzes == reverse_result).long().min(dim=-1).values
            )

            # A quiz counts only if the model solves it in both directions
            nb_correct.append((correct * reverse_correct)[None, :])

        nb_correct = torch.cat(nb_correct, dim=0)

        filename = os.path.join(result_dir, f"correct_{n_epoch:04d}.dat")
        with open(filename, "w") as f:
            for k in nb_correct:
                f.write(f"{k}\n")

        # Return the quizzes, how many of the other models solve each of
        # them in both directions, and the average logit of the
        # generation pass
        return quizzes, nb_correct.sum(dim=0), average_logits
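
######################################################################

# A minimal, hypothetical driver sketch (the model, the number of epochs
# and the training step are assumptions, not defined in this file):
#
#   task = World(
#       nb_train_samples=25000,
#       nb_test_samples=1000,
#       batch_size=25,
#       result_dir="results",
#       logger=print,
#   )
#   for n_epoch in range(nb_epochs):
#       for input in task.batches(split="train"):
#           ...  # one optimization step on the batch `input`
#       task.produce_results(
#           n_epoch, model, "results", print, deterministic_synthesis=True
#       )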