3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, os, tqdm, warnings
10 import torch, torchvision
13 from torch.nn import functional as F
15 from mygpt import BracketedSequence
17 ######################################################################
def masked_inplace_autoregression(
    # NOTE(review): several parameter lines are elided from this excerpt; the
    # names used below (model, input, ar_mask, batch_size, summed_logits,
    # temperature, logit_biases) must be declared on the missing lines.
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc="autoregression",
    device=torch.device("cpu"),
    # Fills, in place, the positions of `input` selected by `ar_mask`,
    # delegating the actual per-batch generation to the model's own
    # masked_inplace_autoregression method.
    assert input.size() == ar_mask.size()

    # Process the samples mini-batch by mini-batch.
    batches = zip(input.split(batch_size), ar_mask.split(batch_size))

    if progress_bar_desc is not None:
        # Keyword arguments of an elided tqdm.tqdm(...) call that wraps
        # `batches` with a progress bar; `total` is ceil(nb_samples / batch_size).
        desc=progress_bar_desc,
        total=(input.size(0) + batch_size - 1) // batch_size,

    # Pure generation/evaluation: no gradient tracking needed.
    with torch.autograd.no_grad():
        # NOTE(review): the model is presumably switched to eval mode on an
        # elided line and restored afterwards — confirm against the full file.
        for input, ar_mask in batches:
            model.masked_inplace_autoregression(
                summed_logits=summed_logits,
                temperature=temperature,
                deterministic_synthesis=deterministic_synthesis,
                forbidden_tokens=forbidden_tokens,
                forced_biases=logit_biases,
63 ######################################################################
# Fragment of an elided base-class interface that concrete tasks implement;
# the enclosing `class` line and all method bodies are missing from this excerpt.
def batches(self, split="train", nb_to_use=-1, desc=None):
    # body elided in this excerpt

def vocabulary_size(self):
    # body elided in this excerpt

# Parameter list of an elided `def produce_results(` line:
    self, n_epoch, model, result_dir, logger, deterministic_synthesis
79 ######################################################################
def save_image(self, input, result_dir, filename, logger, nrow=6, padding=4):
    """Render a batch of quiz sequences as an image grid and write it to disk.

    Args:
        input: tensor of token sequences; moved to CPU and converted to an
            image via world.seq2img with this task's height/width.
        result_dir: directory in which to write the image file.
        filename: file name for the image (e.g. "quizzes.png").
        logger: callable taking a string; used to report the written path.
        nrow: images per grid row (default 6, the previously hard-coded value).
        padding: pixels between grid cells (default 4, the previous value).
    """
    img = world.seq2img(input.to("cpu"), self.height, self.width)
    image_name = os.path.join(result_dir, filename)
    # seq2img presumably yields integer pixel values in [0, 255] — the
    # division rescales to the [0, 1] float range save_image expects.
    torchvision.utils.save_image(
        img.float() / 255.0, image_name, nrow=nrow, padding=padding
    )
    logger(f"wrote {image_name}")
def save_quizzes(self, input, result_dir, filename_prefix, logger):
    """Save a batch of quizzes as a single PNG named after `filename_prefix`."""
    filename = filename_prefix + ".png"
    self.save_image(input, result_dir, filename, logger)
def make_ar_mask(self, input):
    """Return a 0/1 mask, shaped like `input`, selecting the positions past
    the midpoint of each sequence (the part to be generated)."""
    positions = torch.arange(input.size(1), device=input.device)
    second_half = (positions > input.size(1) // 2).long()
    return second_half[None, :].expand_as(input)
# Fragment of the task constructor: the `def __init__(...)` line and several
# parameter lines are elided from this excerpt.
    device=torch.device("cpu"),

    self.batch_size = batch_size

    # World ("w") quizzes: procedurally generated train/test token sequences.
    self.train_w_quizzes = world.generate_seq(
        nb_train_samples, height=self.height, width=self.width
    # (closing of the call above and opening of the next assignment elided)
    self.test_w_quizzes = world.generate_seq(
        nb_test_samples, height=self.height, width=self.width

    # Vocabulary size is one past the largest token value actually produced.
    self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1

    # Culture ("c") quizzes are accumulated later (see store_c_quizzes).
    self.train_c_quizzes = []
    self.test_c_quizzes = []

    if result_dir is not None:
        # Arguments of an elided self.save_quizzes(...) call that writes a
        # sample of 72 world quizzes for inspection.
        self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
def batches(self, split="train", desc=None):
    # Yields mini-batches mixing world quizzes with (at most an equal number
    # of) culture quizzes, shuffled together. Several lines — including the
    # train/test `if`/`else` headers and the final `yield` — are elided here.
    assert split in {"train", "test"}
    w_quizzes = self.train_w_quizzes
    c_quizzes = self.train_c_quizzes
    # (elided `else:` branch selects the test sets)
    w_quizzes = self.test_w_quizzes
    c_quizzes = self.test_c_quizzes

    if len(c_quizzes) > 0:
        c_quizzes = torch.cat(c_quizzes, dim=0)
        # Cap culture quizzes at half the number of world quizzes.
        if c_quizzes.size(0) > w_quizzes.size(0) // 2:
            # NOTE(review): randperm is sized by w_quizzes here but indexes
            # c_quizzes — if c_quizzes is smaller than w_quizzes this can go
            # out of bounds; confirm against the full file (likely should be
            # torch.randperm(c_quizzes.size(0))).
            i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
            c_quizzes = c_quizzes[i]

        # Subsample world quizzes so the combined epoch size stays constant.
        i = torch.randperm(w_quizzes.size(0))[
            : w_quizzes.size(0) - c_quizzes.size(0)
        w_quizzes = w_quizzes[i]

        self.nb_batch_w_quizzes = w_quizzes.size(0)
        self.nb_batch_c_quizzes = c_quizzes.size(0)

        input = torch.cat([w_quizzes, c_quizzes], dim=0)
    # (elided `else:` branch — no culture quizzes yet)
        self.nb_batch_w_quizzes = w_quizzes.size(0)
        self.nb_batch_c_quizzes = 0

    # Shuffle the combined pool before batching.
    input = input[torch.randperm(input.size(0))]
    # (elided: `if desc is None:`)
    desc = f"epoch-{split}"
    for batch in tqdm.tqdm(
        input.split(self.batch_size), dynamic_ncols=True, desc=desc
def vocabulary_size(self):
    # body elided in this excerpt (presumably returns self.nb_codes — verify)
# Parameter list of an elided `def produce_results(` line; measures the
# model's prediction accuracy on train/test world quizzes and saves a batch
# of predicted quizzes as an image.
    self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000

    def compute_accuracy(input, logger=None):
        # NOTE(review): a line truncating `input` to nmax samples is likely
        # elided here — confirm against the full file.
        ar_mask = self.make_ar_mask(input)
        # Zero out the masked (to-be-predicted) second half, keep the prompt.
        result = input.clone() * (1 - ar_mask)

        masked_inplace_autoregression(
            # (model/input/ar_mask arguments elided)
            batch_size=self.batch_size,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,

        # A sample counts as correct only if every token matches (min over
        # the sequence dimension).
        nb_total, nb_correct = (
            # (elided: input.size(0) total count)
            (input == result).long().min(dim=1).values.sum(),

        return nb_total, nb_correct

    train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
    # (elided logger( call)
        f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"

    test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
    # (elided logger( call)
        f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"

    main_test_accuracy = test_nb_correct / test_nb_total
    logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

    ##############################
    # Regenerate predictions for the first 96 test quizzes and save an image.
    input = self.test_w_quizzes[:96]
    ar_mask = self.make_ar_mask(input)
    result = input.clone() * (1 - ar_mask)

    masked_inplace_autoregression(
        # (model/input/ar_mask arguments elided)
        batch_size=self.batch_size,
        deterministic_synthesis=deterministic_synthesis,
        progress_bar_desc=None,

    # (elided self.save_quizzes(...) call; this is its filename-prefix argument)
        f"culture_prediction_{n_epoch:04d}_{model.id:02d}",

    return main_test_accuracy
def renew_w_quizzes(self, nb, for_train=True):
    # Replace the `nb` oldest world quizzes with freshly generated ones,
    # shifting the rest toward the front (FIFO refresh, in place).
    input = self.train_w_quizzes if for_train else self.test_w_quizzes
    nb = min(nb, input.size(0))
    input[:-nb] = input[nb:].clone()
    # NOTE(review): the tail of this statement (.to(self.device) and the
    # closing parenthesis, presumably) is elided from this excerpt.
    input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
def store_c_quizzes(self, new_c_quizzes, for_train=True):
    # Append validated culture quizzes to the train or test pool; the
    # `if for_train:` / `else:` header lines are elided from this excerpt.
    self.train_c_quizzes.append(new_c_quizzes)
    self.test_c_quizzes.append(new_c_quizzes)
def create_c_quizzes(
    # NOTE(review): most parameters (nb, model, other_models, temperature,
    # d_temperature, logger, ...) are on lines elided from this excerpt.
    desired_average_logits=None,
    ###############################################################
    # Generate quizzes with model

    # A quiz is two grids plus one direction token in between.
    c_quizzes = torch.empty(
        nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64

    # Every position is generated (mask of all ones).
    ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
    summed_logits = torch.empty(nb, device=self.device)

    # (elided: loop header — temperature is adjusted until the average summed
    # logits approach desired_average_logits)
    summed_logits[...] = 0

    masked_inplace_autoregression(
        # (model/input/ar_mask arguments elided)
        batch_size=self.batch_size,
        summed_logits=summed_logits,
        temperature=temperature,
        deterministic_synthesis=False,
        progress_bar_desc="sampling c_quizzes",

    average_logits = summed_logits.mean()

    logger(f"{average_logits=} {desired_average_logits=}")

    # (elided: `break` presumably follows — no target to match)
    if desired_average_logits is None:

    # Damped search: halve and flip the temperature step whenever we
    # overshoot the target band [desired, desired * 1.1].
    if average_logits < desired_average_logits * 1.1:
        if d_temperature > 0:
            d_temperature *= -0.5
        temperature += d_temperature
    elif average_logits > desired_average_logits:
        if d_temperature < 0:
            d_temperature *= -0.5
        temperature += d_temperature

    # NOTE(review): "chaging" is a typo for "changing" in this runtime log
    # message — fix when editing the full file.
    logger(f"chaging temperature to {temperature}")

    ###############################################################
    # Create the reverse quizzes

    # Swap the direction token (forward <-> backward) and exchange the two
    # grids, producing the same quiz posed in the opposite direction.
    l = self.height * self.width
    direction = c_quizzes[:, l : l + 1]
    direction = world.token_forward * (
        direction == world.token_backward
    ) + world.token_backward * (direction == world.token_forward)

    reverse_c_quizzes = torch.cat(
        [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1

    ar_mask = self.make_ar_mask(c_quizzes)

    ###############################################################
    # Check how many of the other models can solve them in both
    # directions; a quiz counts only if both directions are solved.

    for m in other_models:
        result = c_quizzes.clone()

        masked_inplace_autoregression(
            # (model/input/ar_mask arguments elided)
            batch_size=self.batch_size,
            deterministic_synthesis=True,
            progress_bar_desc="solving c_quizzes",

        # Correct iff every token of the completion matches the quiz.
        correct = (c_quizzes == result).long().min(dim=-1).values

        reverse_result = reverse_c_quizzes.clone()

        masked_inplace_autoregression(
            # (model/ar_mask arguments elided)
            batch_size=self.batch_size,
            input=reverse_result,
            deterministic_synthesis=True,
            progress_bar_desc="solving reversed c_quizzes",

        # (elided: `reverse_correct =` assignment wrapping this expression)
            (reverse_c_quizzes == reverse_result).long().min(dim=-1).values

        nb_correct.append((correct * reverse_correct)[None, :])

    nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)

    # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
    # with open(filename, "w") as f:
    # for k in nb_correct:

    return c_quizzes, nb_correct, summed_logits.mean()