# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, os, tqdm, warnings

import torch, torchvision

from torch.nn import functional as F

from mygpt import BracketedSequence

######################################################################

# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated. The others are kept
# unchanged.


def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
):
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    if to_generate.min() > 0:
        model(
            BracketedSequence(input, 0, to_generate.min())
        )  # Needed to initialize the model's cache

    for s in range(to_generate.min(), to_generate.max() + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = output[:, s]

        logits = (logits / temperature).log_softmax(dim=-1)

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        # Accumulate the per-sequence log-proba of the chosen tokens
        all_n = torch.arange(t_next.size(0))
        seq_logproba += logits[all_n, t_next]

        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
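
# A minimal usage sketch, assuming a `model` compatible with
# BracketedSequence and a batch `x` of shape (N, T):
#
#   ar_mask = torch.zeros_like(x)
#   ar_mask[:, T // 2 + 1 :] = 1  # regenerate the second half
#   seq_logproba = torch.zeros(x.size(0))
#   one_batch_masked_inplace_autoregression(model, x, ar_mask, seq_logproba)
#
# On return, the masked positions of x hold the generated tokens and
# seq_logproba the log-probability of each generated sequence.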


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.autograd.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask, seq_logproba in batches:
            one_batch_masked_inplace_autoregression(
                model=model,
                input=input,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=temperature,
                deterministic_synthesis=deterministic_synthesis,
            )

        model.train(t)


######################################################################


class QuizzMachine:
    def make_ar_mask(self, input):
        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
        return b.long()[None, :].expand_as(input)
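
    # E.g. for sequences of width 13 (a direction token followed by two
    # halves of 6 tokens), positions 0..6 are kept as-is and positions
    # 7..12 are generated.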

    def generate_token_sequences(self, nb):
        prompts, answers = self.problem.generate_prompts_and_answers(nb)

        result = []

        for prompt, answer in zip(prompts, answers):
            if torch.rand(1) < 0.5:
                a = [torch.tensor([self.token_forward]), prompt, answer]
            else:
                a = [torch.tensor([self.token_backward]), answer, prompt]

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)
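
    # Each sequence reads [direction token | first half | second half],
    # the halves being (prompt, answer) in the forward direction and
    # (answer, prompt) in the backward one, each ordering drawn with
    # probability 1/2.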

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir,
        logger,
        device=torch.device("cpu"),
    ):
        super().__init__()

        v = problem.nb_token_values()
        self.token_forward = v
        self.token_backward = v + 1
        self.nb_token_values = v + 2

        self.problem = problem
        self.batch_size = batch_size
        self.logger = logger
        self.device = device

        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples).to(
            device
        )

        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)

        self.train_c_quizzes = []
        self.test_c_quizzes = []

        if result_dir is not None:
            self.save_quizzes(
                result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
            )

    def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
        l = (quizzes.size(1) - 1) // 2
        forward = (quizzes[:, 0] == self.token_forward).long()
        backward = (quizzes[:, 0] == self.token_backward).long()
        assert forward.equal(1 - backward)
        first = quizzes[:, 1 : 1 + l]
        second = quizzes[:, 1 + l : 1 + 2 * l]
        prompts = forward[:, None] * first + backward[:, None] * second
        answers = forward[:, None] * second + backward[:, None] * first

        if prediction:
            predicted_prompts = backward
            predicted_answers = forward
        else:
            predicted_prompts = None
            predicted_answers = None

        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts,
            predicted_answers,
        )
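
    # Note that whatever the stored direction, prompts and answers are
    # recovered in canonical (prompt, answer) order by selecting the
    # relevant half with the forward/backward indicators.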

    def batches(self, split="train", desc=None):
        assert split in {"train", "test"}
        if split == "train":
            w_quizzes = self.train_w_quizzes
            c_quizzes = self.train_c_quizzes
        else:
            w_quizzes = self.test_w_quizzes
            c_quizzes = self.test_c_quizzes

        if len(c_quizzes) > 0:
            c_quizzes = torch.cat(c_quizzes, dim=0)
            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                c_quizzes = c_quizzes[i]

            i = torch.randperm(w_quizzes.size(0))[
                : w_quizzes.size(0) - c_quizzes.size(0)
            ]
            w_quizzes = w_quizzes[i]

            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = c_quizzes.size(0)

            input = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            input = w_quizzes
            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = 0

        # Shuffle the full mix
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch
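
    # The c_quizzes are capped to half of the mix, so at least half of
    # every epoch comes from freshly generated w_quizzes.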

    def vocabulary_size(self):
        return self.nb_token_values

    def produce_results(
        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input):
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)
            seq_logproba = torch.empty(input.size(0), device=self.device)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            # A quizz counts as correct only if every token matches
            nb_total, nb_correct = (
                input.size(0),
                (input == result).long().min(dim=1).values.sum(),
            )

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)

        self.logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)

        self.logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

        # Save a sample of test predictions as images

        input = self.test_w_quizzes[:96]
        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)
        seq_logproba = torch.empty(input.size(0), device=self.device)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            seq_logproba=seq_logproba,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_quizzes(
            result_dir,
            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
            quizzes=result[:72],
            prediction=True,
        )

        return main_test_accuracy

    def renew_w_quizzes(self, nb, for_train=True):
        input = self.train_w_quizzes if for_train else self.test_w_quizzes
        nb = min(nb, input.size(0))
        input[:-nb] = input[nb:].clone()
        input[-nb:] = self.generate_token_sequences(nb).to(self.device)
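
    # The buffer is shifted down and its last nb entries are replaced
    # with freshly generated quizzes, so the w_quizzes are renewed
    # gradually rather than all at once.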

    def store_c_quizzes(self, new_c_quizzes, for_train=True):
        if for_train:
            self.train_c_quizzes.append(new_c_quizzes)
        else:
            self.test_c_quizzes.append(new_c_quizzes)

    def reverse_time(self, c_quizzes):
        l = (c_quizzes.size(1) - 1) // 2
        direction = c_quizzes[:, 0:1]
        direction = self.token_forward * (
            direction == self.token_backward
        ) + self.token_backward * (direction == self.token_forward)

        return torch.cat(
            [direction, c_quizzes[:, l + 1 :], c_quizzes[:, 1 : l + 1]], dim=1
        )
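
    # E.g. [token_forward, prompt, answer] becomes
    # [token_backward, answer, prompt]: the direction token is flipped
    # and the two halves are swapped.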

    def compute_correctness(
        self, c_quizzes, models_for_validation, both_directions=False
    ):
        reversed_c_quizzes = self.reverse_time(c_quizzes)

        ar_mask = self.make_ar_mask(c_quizzes)
        seq_logproba = torch.zeros(
            c_quizzes.size(0),
            max([m.id for m in models_for_validation]) + 1,
            device=self.device,
        )

        # Check how many of the models can solve the quizzes in both
        # directions

        nb_correct = 0

        for model in models_for_validation:
            result = c_quizzes.clone()

            seq_logproba[...] = 0.0

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba[:, model.id],
                temperature=1.0,
                deterministic_synthesis=True,
                # progress_bar_desc="solving c_quizzes",
                device=self.device,
            )

            correct = (c_quizzes == result).long().min(dim=-1).values

            if both_directions:
                reversed_result = reversed_c_quizzes.clone()

                masked_inplace_autoregression(
                    model=model,
                    batch_size=self.batch_size,
                    input=reversed_result,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba[:, model.id],
                    temperature=1.0,
                    deterministic_synthesis=True,
                    # progress_bar_desc="solving reversed c_quizzes",
                    device=self.device,
                )

                reversed_correct = (
                    (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
                )

                correct *= reversed_correct

            nb_correct += correct

        return nb_correct, seq_logproba
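
    # nb_correct has one entry per quizz, counting the validation models
    # that reproduced it exactly (in both time directions when
    # both_directions is set).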

    ###############################################################

    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
        c_quizzes = torch.empty(
            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
        )

        # The first mask covers the first half of the sequence, the
        # second one the second half, and neither touches the direction
        # token at position 0

        ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
        ar_mask_second = 1 - ar_mask_first
        ar_mask_first[:, 0] = 0
        ar_mask_second[:, 0] = 0

        seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)

        # First, we generate the answer at high temperature

        c_quizzes[:, 0] = self.token_backward

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_first,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then, we generate the prompt deterministically

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=True,
            device=self.device,
        )

        # Then we reverse the quizz and re-generate the response, now
        # deterministically

        c_quizzes = self.reverse_time(c_quizzes)

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=True,
            device=self.device,
        )

        return c_quizzes
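
######################################################################

# A minimal usage sketch, assuming hypothetical `problem` and `model`
# objects implementing the interfaces used above (including a `.id`
# attribute on the model):
#
#   quizz_machine = QuizzMachine(
#       problem,
#       nb_train_samples=10000,
#       nb_test_samples=1000,
#       batch_size=25,
#       result_dir=None,
#       logger=print,
#   )
#   c_quizzes = quizz_machine.generate_quizzes(100, model_for_generation=model)
#   nb_correct, seq_logproba = quizz_machine.compute_correctness(
#       c_quizzes, [model], both_directions=True
#   )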