# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, os, tqdm, warnings

import torch, torchvision

from torch.nn import functional as F

from mygpt import BracketedSequence

######################################################################

# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated. The others are kept
# unchanged.
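#
# For instance (illustrative values, not taken from this file): with
# an input of shape (1, 5) and
#
#   ar_mask = torch.tensor([[0, 0, 1, 1, 0]])
#
# positions 2 and 3 are re-sampled autoregressively, while positions
# 0, 1 and 4 keep the tokens already present in input.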


def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
):
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    if to_generate.min() > 0:
        model(
            BracketedSequence(input, 0, to_generate.min())
        )  # Needed to initialize the model's cache
    for s in range(to_generate.min(), to_generate.max() + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = output[:, s]

        logits = (logits / temperature).log_softmax(dim=-1)

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        all_n = torch.arange(t_next.size(0))
        seq_logproba += logits[all_n, t_next]

        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
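
# The update above is what makes the generation "masked in place":
# sampled tokens are written only where ar_mask is 1, and the original
# tokens are kept everywhere else.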


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.autograd.no_grad():
        t = model.training
        model.eval()

        for input, ar_mask, seq_logproba in batches:
            one_batch_masked_inplace_autoregression(
                model=model,
                input=input,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=temperature,
                deterministic_synthesis=deterministic_synthesis,
            )

        model.train(t)
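
# A hypothetical usage sketch (shapes and values are assumptions, not
# taken from this file): complete the last 4 tokens of 53 sequences of
# length 10, in sub-batches of 25:
#
#   input = torch.full((53, 10), 3)
#   ar_mask = (torch.arange(10) >= 6).long()[None, :].expand(53, -1)
#   seq_logproba = torch.zeros(53)
#   masked_inplace_autoregression(
#       model=model,
#       batch_size=25,
#       input=input,
#       ar_mask=ar_mask,
#       seq_logproba=seq_logproba,
#       temperature=1.0,
#       deterministic_synthesis=False,
#   )
#
# input is modified in place, and seq_logproba accumulates for every
# sequence the log-probability of the tokens that were generated.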


######################################################################


class QuizzMachine:
    def make_ar_mask(self, input, first, nb):
        i = torch.arange(input.size(1), device=input.device)
        b = torch.logical_and(i >= first, i < first + nb)
        return b.long()[None, :].expand_as(input)
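
    # For instance (illustrative values): on an input of shape (2, 6),
    # make_ar_mask(input, first=2, nb=3) returns
    #
    #   [[0, 0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1, 0]]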

    def generate_token_sequences(self, nb):
        prompts, answers = self.problem.generate_prompts_and_answers(nb)

        if self.prompt_len is None:
            self.prompt_len = prompts.size(1)

        if self.answer_len is None:
            self.answer_len = answers.size(1)

        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len

        result = []

        for prompt, answer in zip(prompts, answers):
            if torch.rand(1) < 0.5:
                a = [
                    torch.tensor([self.token_forward]),
                    prompt,
                    torch.tensor([self.token_forward]),
                    answer,
                ]
            else:
                a = [
                    torch.tensor([self.token_backward]),
                    answer,
                    torch.tensor([self.token_backward]),
                    prompt,
                ]

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)
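
    # Each sequence returned above is thus laid out as either
    #
    #   [token_forward,  prompt, token_forward,  answer]
    #
    # or, with probability 1/2, time-reversed as
    #
    #   [token_backward, answer, token_backward, prompt]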

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        result_dir,
        logger,
        device=torch.device("cpu"),
    ):
        super().__init__()

        v = problem.nb_token_values()
        self.token_forward = v
        self.token_backward = v + 1
        self.nb_token_values = v + 2

        self.problem = problem
        self.batch_size = batch_size
        self.device = device
        self.logger = logger
        self.prompt_len = None
        self.answer_len = None

        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples).to(
            device
        )

        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)

        self.train_c_quizzes = []
        self.test_c_quizzes = []

        if result_dir is not None:
            self.save_quizzes(
                result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
            )

    def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
        # A quiz is [direction, first half, direction, second half], so
        # the two direction tokens must be skipped when slicing
        l = (quizzes.size(1) - 2) // 2
        forward = (quizzes[:, 0] == self.token_forward).long()
        backward = (quizzes[:, 0] == self.token_backward).long()
        assert forward.equal(1 - backward)
        first = quizzes[:, 1 : 1 + l]
        second = quizzes[:, 2 + l : 2 + 2 * l]
        prompts = forward[:, None] * first + backward[:, None] * second
        answers = forward[:, None] * second + backward[:, None] * first

        if prediction:
            predicted_prompts = backward
            predicted_answers = forward
        else:
            predicted_prompts = None
            predicted_answers = None

        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            prompts,
            answers,
            predicted_prompts,
            predicted_answers,
        )

    def batches(self, split="train", desc=None):
        assert split in {"train", "test"}
        if split == "train":
            w_quizzes = self.train_w_quizzes
            c_quizzes = self.train_c_quizzes
        else:
            w_quizzes = self.test_w_quizzes
            c_quizzes = self.test_c_quizzes

        if len(c_quizzes) > 0:
            c_quizzes = torch.cat(c_quizzes, dim=0)
            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                c_quizzes = c_quizzes[i]

            i = torch.randperm(w_quizzes.size(0))[
                : w_quizzes.size(0) - c_quizzes.size(0)
            ]
            w_quizzes = w_quizzes[i]

            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = c_quizzes.size(0)

            input = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            input = w_quizzes
            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch
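
    # For instance (illustrative sizes): with 10000 w_quizzes and
    # 20000 stored c_quizzes, an epoch is built from 5000 randomly
    # chosen c_quizzes (they are capped at half the w_quizzes pool)
    # mixed with 5000 randomly chosen w_quizzes, shuffled together.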

    def vocabulary_size(self):
        return self.nb_token_values

    def produce_results(
        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
    ):
        def compute_accuracy(input):
            input = input[:nmax]
            ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
            result = input.clone() * (1 - ar_mask)
            seq_logproba = torch.empty(input.size(0), device=self.device)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            nb_total = input.size(0)
            # A sequence counts as correct only if every single token,
            # in particular the re-generated answer, matches the input
            nb_correct = (input == result).long().min(dim=1).values.sum()

            return nb_total, nb_correct

        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)

        self.logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )

        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)

        self.logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        )

        main_test_accuracy = test_nb_correct / test_nb_total
        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

        input = self.test_w_quizzes[:96]
        ar_mask = self.make_ar_mask(input, 2 + self.prompt_len, self.answer_len)
        result = input.clone() * (1 - ar_mask)
        seq_logproba = torch.empty(input.size(0), device=self.device)

        masked_inplace_autoregression(
            model=model,
            batch_size=self.batch_size,
            input=result,
            ar_mask=ar_mask,
            seq_logproba=seq_logproba,
            temperature=1.0,
            deterministic_synthesis=deterministic_synthesis,
            progress_bar_desc=None,
            device=self.device,
        )

        self.save_quizzes(
            result_dir,
            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
            quizzes=result[:72],
            prediction=True,
        )

        return main_test_accuracy

    def renew_w_quizzes(self, nb, for_train=True):
        input = self.train_w_quizzes if for_train else self.test_w_quizzes
        nb = min(nb, input.size(0))
        input[:-nb] = input[nb:].clone()
        input[-nb:] = self.generate_token_sequences(nb).to(self.device)
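
    # For instance (illustrative value), renew_w_quizzes(1000) shifts
    # out the first 1000 training quizzes and appends 1000 freshly
    # generated ones, keeping the tensor size constant.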

    def store_c_quizzes(self, new_c_quizzes, for_train=True):
        if for_train:
            self.train_c_quizzes.append(new_c_quizzes)
        else:
            self.test_c_quizzes.append(new_c_quizzes)

    def forward_to_backward(self, c_quizzes):
        prompts = c_quizzes[:, 1 : 1 + self.prompt_len]
        answers = c_quizzes[:, 2 + self.prompt_len :]
        c_quizzes = torch.cat(
            [
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
                answers,
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_backward),
                prompts,
            ],
            dim=1,
        )

        return c_quizzes

    def backward_to_forward(self, c_quizzes):
        answers = c_quizzes[:, 1 : 1 + self.answer_len]
        prompts = c_quizzes[:, 2 + self.answer_len :]
        c_quizzes = torch.cat(
            [
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
                prompts,
                c_quizzes.new_full((c_quizzes.size(0), 1), self.token_forward),
                answers,
            ],
            dim=1,
        )

        return c_quizzes
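
    # Schematically, these two helpers just permute the token blocks:
    #
    #   forward_to_backward: [fwd, prompt, fwd, answer] -> [bwd, answer, bwd, prompt]
    #   backward_to_forward: [bwd, answer, bwd, prompt] -> [fwd, prompt, fwd, answer]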

    def compute_correctness(
        self,
        c_quizzes,
        models_for_validation,
        bidirectional_validation=False,
        deterministic_validation=True,
    ):
        if bidirectional_validation:
            backward_c_quizzes = self.forward_to_backward(c_quizzes)

        seq_logproba = torch.zeros(
            c_quizzes.size(0),
            max([m.id for m in models_for_validation]) + 1,
            device=self.device,
        )

        nb_correct = 0

        for model in models_for_validation:
            result = c_quizzes.clone()

            seq_logproba[...] = 0.0

            ar_mask = self.make_ar_mask(result, 2 + self.prompt_len, self.answer_len)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba[:, model.id],
                temperature=1.0,
                deterministic_synthesis=deterministic_validation,
                # progress_bar_desc="solving c_quizzes",
                device=self.device,
            )

            correct = (c_quizzes == result).long().min(dim=-1).values

            if bidirectional_validation:
                backward_result = backward_c_quizzes.clone()

                ar_mask = self.make_ar_mask(
                    backward_result, 2 + self.answer_len, self.prompt_len
                )

                masked_inplace_autoregression(
                    model=model,
                    batch_size=self.batch_size,
                    input=backward_result,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba[:, model.id],
                    temperature=1.0,
                    deterministic_synthesis=deterministic_validation,
                    # progress_bar_desc="solving backward c_quizzes",
                    device=self.device,
                )

                backward_correct = (
                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
                )

                correct *= backward_correct

            nb_correct += correct

        return nb_correct, seq_logproba
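
    # A hypothetical usage sketch (the keep-threshold is an assumption,
    # not from this file): each entry of nb_correct counts how many of
    # the validating models re-derived the exact answer of the
    # corresponding c_quiz, so e.g.
    #
    #   nb_correct, seq_logproba = quizz_machine.compute_correctness(
    #       c_quizzes, models_for_validation=models
    #   )
    #   to_keep = c_quizzes[nb_correct == len(models) - 1]
    #
    # would keep the quizzes solved by all the models but exactly one.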

    ######################################################################

    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
        c_quizzes = torch.empty(
            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
        )

        # The first mask covers the first half of the sequence (the
        # answer, since generation happens in backward order), the
        # second mask covers the rest; position 0 holds the direction
        # token and is excluded from both
        ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
        ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
        ar_mask_second = 1 - ar_mask_first
        ar_mask_first[:, 0] = 0
        ar_mask_second[:, 0] = 0

        seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)

        # First, we generate the answer at high temperature

        c_quizzes[:, 0] = self.token_backward

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_first,
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then, we generate the prompt at low temperature

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then we flip the quiz to forward order and re-generate the
        # response, still at low temperature

        c_quizzes = self.backward_to_forward(c_quizzes)

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=ar_mask_second,
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        return c_quizzes
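
    # A hypothetical usage sketch (sizes and temperature are
    # assumptions, not from this file):
    #
    #   c_quizzes = quizz_machine.generate_quizzes(
    #       100, model_for_generation=model, temperature=2.0
    #   )
    #
    # Here the answers are sampled at temperature 2.0 for diversity,
    # while the prompts and the final answers are sampled at 0.5 to
    # keep them consistent with each other.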