#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import os, tqdm

import torch, torchvision

# from graph import save_attention_image
save_attention_image = None

######################################################################


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
):
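    """Generate in place the entries of `input` for which `ar_mask` is 1, by
    delegating to the model's own masked_inplace_autoregression, one chunk of
    `batch_size` samples at a time. The model is run in eval mode under
    no_grad, and its original training mode is restored afterwards."""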
    assert input.size() == ar_mask.size()

    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

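    # Generate without gradients and in eval mode, then restore the model's
    # original train/eval state.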
    with torch.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                input,
                ar_mask,
                deterministic_synthesis,
                forbidden_tokens,
                logit_biases,
            )

        model.train(t)


######################################################################


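# Abstract task interface: concrete tasks supply batches, a vocabulary size,
# and an evaluation routine.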
class Task:
    def batches(self, split="train", nb_to_use=-1, desc=None):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        pass


######################################################################

import world


class World(Task):
    def save_image(self, input, result_dir, filename, logger):
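        # Render the token sequences as images and save them as a mosaic,
        # eight per row.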
        img = world.sample2img(input.to("cpu"), self.height, self.width)
        image_name = os.path.join(result_dir, filename)
        torchvision.utils.save_image(
            img.float() / 255.0, image_name, nrow=8, padding=2
        )
        logger(f"wrote {image_name}")

    def make_ar_mask(self, input):
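        # 1 on positions strictly past the middle of the sequence (the part
        # to generate), 0 on the prompt, for every sample of the batch.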
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir=None,
        logger=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device
        self.height = 6
        self.width = 8

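        # Pre-generate the complete train and test sets of world sequences.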
        self.train_input = world.generate(
            nb_train_samples, height=self.height, width=self.width
        ).to(device)

        self.test_input = world.generate(
            nb_test_samples, height=self.height, width=self.width
        ).to(device)

        self.nb_codes = int(max(self.train_input.max(), self.test_input.max()) + 1)

        self.train_quizzes = []
        self.test_quizzes = []

        if result_dir is not None:
            self.save_image(
                self.train_input[:96], result_dir, "world_train.png", logger
            )

    def batches(self, split="train", desc=None):
        assert split in {"train", "test"}
        if split == "train":
            input = self.train_input
            quizzes = self.train_quizzes
        else:
            input = self.test_input
            quizzes = self.test_quizzes

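        # If generated quizzes have been stored, cap them at half of the pool
        # and fill the remainder with world samples before batching.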
        if len(quizzes) > 0:
            quizzes = torch.cat(quizzes, dim=0)
            if quizzes.size(0) > input.size(0) // 2:
                i = torch.randperm(quizzes.size(0))[: input.size(0) // 2]
                quizzes = quizzes[i]

            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
            input = input[i]

            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = quizzes.size(0)

            input = torch.cat([input, quizzes], dim=0)
        else:
            self.nb_batch_samples_world = input.size(0)
            self.nb_batch_samples_quizzes = 0

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input, logger=None):
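            # Re-generate the masked half of each sequence and count how many
            # samples are reconstructed exactly.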
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)

            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

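            # A sample is correct only if every one of its tokens matches,
            # hence the min over the token dimension.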
            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)

        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.2f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)

        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.2f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

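        # Save an image of the model's predictions on the first test samples.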
        input = self.test_input[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)

        masked_inplace_autoregression(
            model,
            self.batch_size,
            result,
            ar_mask,
            deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_image(
            result[:96],
            result_dir,
            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
            logger,
        )

        return main_test_accuracy

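    # Stash freshly generated quizzes; batches() will mix them back into the
    # training or test stream.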
    def store_new_quizzes(self, new_quizzes, for_train=True):
        if for_train:
            self.train_quizzes.append(new_quizzes)
        else:
            self.test_quizzes.append(new_quizzes)

    def create_new_quizzes(
        self,
        n_epoch,
        result_dir,
        logger,
        nb,
        model,
        other_models,
    ):
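        """Let `model` sample `nb` candidate quizzes from scratch, then have
        every model in `other_models` solve each candidate both forward and
        backward. Return the candidates together with, for each one, the
        number of models that solved it exactly in both directions."""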
        new_quizzes = torch.empty(
            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
        )
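        # Mask every position: the generating model samples complete quizzes
        # from scratch.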
        ar_mask = torch.full(new_quizzes.size(), 1, device=self.device)

        masked_inplace_autoregression(
            model,
            self.batch_size,
            new_quizzes,
            ar_mask,
            deterministic_synthesis=False,
            progress_bar_desc="new quizzes",
            device=self.device,
        )

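        # Back to the standard mask: only the second half of each quiz is to
        # be regenerated by the solvers.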
        ar_mask = self.make_ar_mask(new_quizzes)

        nb_correct = 0

        for m in other_models:
            result = new_quizzes.clone()

            masked_inplace_autoregression(
                m,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis=True,
                progress_bar_desc="solving quizzes",
                device=self.device,
            )

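            # Build the reversed quiz: swap the forward/backward direction
            # token and exchange the two grid halves.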
            l = self.height * self.width
            direction = new_quizzes[:, l : l + 1]
            direction = world.token_forward * (
                direction == world.token_backward
            ) + world.token_backward * (direction == world.token_forward)
            inverted_quizzes = torch.cat(
                [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
            )

            inverted_result = inverted_quizzes.clone()

            masked_inplace_autoregression(
                m,
                self.batch_size,
                inverted_result,
                ar_mask,
                deterministic_synthesis=True,
                progress_bar_desc="solving reverse quizzes",
                device=self.device,
            )

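            # A candidate counts as solved by m only if both the original and
            # the reversed quiz are reconstructed token for token.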
            nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
                inverted_quizzes == inverted_result
            ).long().min(dim=-1).values

        return new_quizzes, nb_correct
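

######################################################################

# Minimal usage sketch; assumes a `model` object exposing `training`,
# `eval()`, `train()`, an integer `id` attribute, and the
# `masked_inplace_autoregression(...)` method invoked above:
#
#   task = World(nb_train_samples=25000, nb_test_samples=1000, batch_size=25)
#   for input in task.batches(split="train"):
#       ...  # one optimization step on `input`
#   accuracy = task.produce_results(
#       n_epoch=0,
#       model=model,
#       result_dir=".",
#       logger=print,
#       deterministic_synthesis=True,
#   )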