#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret

import math
import os
import warnings

import torch
import torchvision
import tqdm

from torch import nn
from torch.nn import functional as F

import mygpt
from mygpt import BracketedSequence

######################################################################

# ar_mask is a tensor with 0s and 1s, of same shape as input, with
# 1s where tokens should be generated. The others are kept
# unchanged.


def one_batch_masked_inplace_autoregression(
    model,
    input,
    ar_mask,
    seq_logproba,
    temperature=1.0,
    deterministic_synthesis=False,
):
    """Generate, in place, the tokens of one batch selected by `ar_mask`.

    Args:
        model: callable taking a `BracketedSequence` and returning an
            object whose `.x` holds the logits.
        input: token tensor, modified in place at the masked positions.
        ar_mask: 0/1 tensor of the same shape as `input`; 1 marks the
            tokens to generate.
        seq_logproba: per-sequence tensor, incremented in place with the
            log-probability of the token picked at every visited position.
        temperature: divisor applied to the logits before the softmax.
        deterministic_synthesis: if true take the argmax, else sample.
    """
    to_generate = (ar_mask.sum(0) > 0).nonzero()

    # Nothing is masked: min() / max() of an empty tensor would raise.
    if to_generate.numel() == 0:
        return

    first, last = to_generate.min(), to_generate.max()

    if first > 0:
        model(BracketedSequence(input, 0, first))  # Needed to initialize the model's cache

    for s in range(first, last + 1):
        output = model(BracketedSequence(input, s, 1)).x

        logits = (output[:, s] / temperature).log_softmax(dim=-1)

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()

        # Build the row index on the same device as the logits
        all_n = torch.arange(t_next.size(0), device=logits.device)

        # NOTE(review): the log-probability is accumulated for every
        # row, including those whose mask is 0 at position s and whose
        # sampled token is therefore discarded -- confirm this is
        # intended.
        seq_logproba += logits[all_n, t_next]

        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]


def masked_inplace_autoregression(
    model,
    batch_size,
    input,
    ar_mask,
    seq_logproba,
    temperature,
    deterministic_synthesis,
    forbidden_tokens=None,
    logit_biases=None,
    progress_bar_desc=None,
    device=torch.device("cpu"),
):
    """Run the masked in-place generation over `input`, batch by batch.

    `input` and `seq_logproba` are modified in place, since the batches
    are obtained with `split`. The model is put in eval mode for the
    generation and its previous training mode is restored afterwards,
    even if the generation raises.

    `forbidden_tokens`, `logit_biases` and `device` are accepted but not
    used by this function.
    """
    assert input.size() == ar_mask.size()

    batches = zip(
        input.split(batch_size),
        ar_mask.split(batch_size),
        seq_logproba.split(batch_size),
    )

    if progress_bar_desc is not None:
        batches = tqdm.tqdm(
            batches,
            dynamic_ncols=True,
            desc=progress_bar_desc,
            total=(input.size(0) + batch_size - 1) // batch_size,
        )

    with torch.autograd.no_grad():
        was_training = model.training
        model.eval()

        try:
            for batch_input, batch_ar_mask, batch_seq_logproba in batches:
                one_batch_masked_inplace_autoregression(
                    model=model,
                    input=batch_input,
                    ar_mask=batch_ar_mask,
                    seq_logproba=batch_seq_logproba,
                    temperature=temperature,
                    deterministic_synthesis=deterministic_synthesis,
                )
        finally:
            model.train(was_training)


######################################################################


class QuizMachine:
    """Generate, store, reverse and evaluate token-encoded quizzes.

    A "forward" quiz is laid out as

        [token_forward, prompt, token_forward, answer]

    and a "backward" quiz as

        [token_backward, answer, token_backward, prompt]

    where the prompt has `prompt_len` tokens and the answer has
    `answer_len` tokens.
    """

    def indices_forward_and_backward(self, quizzes):
        """Return two boolean row masks: forward quizzes, backward quizzes.

        Asserts that every row is a well-formed forward or backward
        quiz, i.e. that both of its direction tokens agree.
        """
        i_forward = quizzes[:, 0] == self.token_forward
        j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
        i_backward = quizzes[:, 0] == self.token_backward
        j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
        assert torch.logical_or(
            torch.logical_and(i_forward, j_forward),
            torch.logical_and(i_backward, j_backward),
        ).all()
        return i_forward, i_backward

    def reverse_time(self, quizzes):
        """Return a new tensor where every quiz has its direction flipped.

        Forward rows become backward rows and vice versa: the prompt and
        answer segments are swapped and both direction tokens are
        replaced accordingly.
        """
        i_forward, i_backward = self.indices_forward_and_backward(quizzes)

        forward_to_backward = torch.cat(
            [
                quizzes[:, 0:1],
                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
                quizzes[:, 1 : 1 + self.prompt_len],
            ],
            dim=1,
        )

        forward_to_backward[:, 0] = self.token_backward
        forward_to_backward[:, 1 + self.answer_len] = self.token_backward

        backward_to_forward = torch.cat(
            [
                quizzes[:, 0:1],
                quizzes[:, 2 + self.answer_len :],
                quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
                quizzes[:, 1 : 1 + self.answer_len],
            ],
            dim=1,
        )

        backward_to_forward[:, 0] = self.token_forward
        backward_to_forward[:, 1 + self.prompt_len] = self.token_forward

        # Per-row selection between the two candidate results
        m = i_forward.long()[:, None]

        return m * forward_to_backward + (1 - m) * backward_to_forward

    def reverse_random_half_in_place(self, quizzes):
        """Flip, in place, the direction of each quiz with probability 0.5."""
        i = torch.rand(quizzes.size(0)) < 0.5
        if i.any():
            quizzes[i] = self.reverse_time(quizzes[i])

    def make_ar_mask(self, quizzes, first=False):
        """Return the 0/1 mask of the tokens to generate for each quiz.

        With `first=False` the mask covers everything after the second
        direction token; with `first=True` it covers the segment between
        the two direction tokens.
        """
        i_forward, i_backward = self.indices_forward_and_backward(quizzes)

        t = torch.arange(quizzes.size(1), device=quizzes.device)

        if first:
            m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
            m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
        else:
            m_forward = (t >= 2 + self.prompt_len).long()
            m_backward = (t >= 2 + self.answer_len).long()

        m = i_forward.long()[:, None]

        return m * m_forward + (1 - m) * m_backward

    def generate_token_sequences(self, nb):
        """Return `nb` forward quizzes built from freshly generated
        prompts and answers of `self.problem`.

        The first call records `prompt_len` and `answer_len`; later
        calls assert that the lengths did not change.
        """
        prompts, answers = self.problem.generate_prompts_and_answers(nb)

        if self.prompt_len is None:
            self.prompt_len = prompts.size(1)

        if self.answer_len is None:
            self.answer_len = answers.size(1)

        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len

        result = []

        for prompt, answer in zip(prompts, answers):
            a = [
                torch.tensor([self.token_forward]),
                prompt,
                torch.tensor([self.token_forward]),
                answer,
            ]

            result.append(torch.cat(a, dim=0)[None, :])

        return torch.cat(result, dim=0)

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        back_accuracy,
        batch_size,
        result_dir,
        logger,
        device=torch.device("cpu"),
    ):
        """Create the train and test world quizzes.

        Args:
            problem: object providing `nb_token_values()`,
                `generate_prompts_and_answers(nb)` and `save_quizzes(...)`.
            nb_train_samples: number of train quizzes to generate.
            nb_test_samples: number of test quizzes to generate.
            back_accuracy: if true, backward quizzes are scored in
                `produce_results` with B->A*->B*=B instead of B->A*=A.
            batch_size: batch size for the batches and the generation.
            result_dir: if not None, a sample of train quizzes is saved
                there.
            logger: callable taking a string.
            device: device on which the quizzes are stored.
        """
        super().__init__()

        # Two extra tokens encode the direction of a quiz
        v = problem.nb_token_values()
        self.token_forward = v
        self.token_backward = v + 1
        self.nb_token_values = v + 2

        self.problem = problem
        self.back_accuracy = back_accuracy
        self.batch_size = batch_size
        self.device = device
        self.logger = logger
        self.prompt_len = None
        self.answer_len = None

        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
        self.reverse_random_half_in_place(self.train_w_quizzes)
        self.train_w_quizzes = self.train_w_quizzes.to(device)

        # Same procedure as for the train quizzes: reverse first, then
        # move to the device once
        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples)
        self.reverse_random_half_in_place(self.test_w_quizzes)
        self.test_w_quizzes = self.test_w_quizzes.to(device)

        self.train_c_quizzes = []
        self.test_c_quizzes = []

        if result_dir is not None:
            self.save_quizzes(
                result_dir,
                "culture_w_quizzes",
                self.train_w_quizzes[:72],
            )

    def save_quizzes(
        self,
        result_dir,
        filename_prefix,
        quizzes,
        mistakes=None,
    ):
        """Save quizzes through `self.problem.save_quizzes`.

        Backward quizzes are flipped to the forward layout on a copy, so
        the caller's tensor is left untouched. `mistakes`, when given,
        is a per-quiz tensor multiplied into the per-quiz markers.
        """
        quizzes = quizzes.clone()
        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
        n_backward = quizzes[:, 0] == self.token_backward
        backward = quizzes[n_backward]
        assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])

        predicted_prompts = n_backward.long()
        predicted_answers = 1 - predicted_prompts
        if mistakes is not None:
            # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
            predicted_prompts *= mistakes
            predicted_answers *= mistakes
        else:
            # 0/2 ~ not-to-predict / to predict
            predicted_prompts *= 2
            predicted_answers *= 2

        self.problem.save_quizzes(
            result_dir,
            filename_prefix,
            quizzes[:, 1 : 1 + self.prompt_len],
            quizzes[:, 2 + self.prompt_len :],
            predicted_prompts,
            predicted_answers,
        )

    def batches(self, split="train", desc=None):
        """Yield shuffled batches of world and culture quizzes.

        The culture quizzes make up at most half of the samples; world
        quizzes are sub-sampled so that the total number of samples
        equals the number of world quizzes of the split.
        """
        assert split in {"train", "test"}
        if split == "train":
            w_quizzes = self.train_w_quizzes
            c_quizzes = self.train_c_quizzes
        else:
            w_quizzes = self.test_w_quizzes
            c_quizzes = self.test_c_quizzes

        if len(c_quizzes) > 0:
            c_quizzes = torch.cat(c_quizzes, dim=0)
            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                c_quizzes = c_quizzes[i]

            i = torch.randperm(w_quizzes.size(0))[
                : w_quizzes.size(0) - c_quizzes.size(0)
            ]
            w_quizzes = w_quizzes[i]

            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = c_quizzes.size(0)

            input = torch.cat([w_quizzes, c_quizzes], dim=0)
        else:
            input = w_quizzes
            self.nb_batch_w_quizzes = w_quizzes.size(0)
            self.nb_batch_c_quizzes = 0

        # Shuffle
        input = input[torch.randperm(input.size(0))]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return the number of token values, direction tokens included."""
        return self.nb_token_values

    def produce_results(
        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
    ):
        """Log the accuracy of `model` on at most `nmax` train and test
        world quizzes, save a sample of test predictions, and return the
        test accuracy.
        """

        def compute_accuracy(input, log_prefix=None):
            # Returns the quizzes completed by the model and a per-quiz
            # 0/1 correctness tensor
            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)
            # Zero-initialized since the generation accumulates into it
            seq_logproba = torch.zeros(input.size(0), device=self.device)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba,
                temperature=1.0,
                deterministic_synthesis=deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            n_forward = input[:, 0] == self.token_forward
            n_backward = input[:, 0] == self.token_backward

            # Default criterion for every quiz: the generated part is
            # identical to the reference (A->B*=B and B->A*=A). This
            # also gives defined values to the backward quizzes when
            # back_accuracy is off.
            correct = (input == result).long().min(dim=1).values

            if self.back_accuracy and n_backward.any():
                # accuracy of B->A*->B*=B instead of B->A*=A
                back_input = self.reverse_time(result[n_backward])
                back_input[:, 2 + self.prompt_len :] = input[
                    n_backward, 1 : 1 + self.answer_len
                ]
                _, back_correct = compute_accuracy(back_input)
                correct[n_backward] = back_correct

            if log_prefix is not None:
                forward_nb_correct = correct[n_forward].sum()
                forward_nb_total = correct[n_forward].size(0)
                backward_nb_correct = correct[n_backward].sum()
                backward_nb_total = correct[n_backward].size(0)

                self.logger(
                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
                )

                self.logger(
                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
                )

            return result, correct

        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")

        test_result, test_correct = compute_accuracy(
            self.test_w_quizzes[:nmax], log_prefix="test"
        )

        main_test_accuracy = test_correct.sum() / test_correct.size(0)
        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")

        ##############################

        self.save_quizzes(
            result_dir,
            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
            quizzes=test_result[:72],
            mistakes=test_correct[:72] * 2 - 1,
        )

        return main_test_accuracy

    def renew_w_quizzes(self, nb, for_train=True):
        """Replace, in place, the `nb` oldest world quizzes of the chosen
        split with freshly generated ones, appended at the end.
        """
        input = self.train_w_quizzes if for_train else self.test_w_quizzes
        nb = min(nb, input.size(0))

        # input[:-0] is an empty slice, so nb == 0 must be handled apart
        if nb <= 0:
            return

        input[:-nb] = input[nb:].clone()
        fresh_w_quizzes = self.generate_token_sequences(nb)
        self.reverse_random_half_in_place(fresh_w_quizzes)
        input[-nb:] = fresh_w_quizzes.to(self.device)

    def store_c_quizzes(self, new_c_quizzes, for_train=True):
        """Append a tensor of culture quizzes to the train or test list."""
        if for_train:
            self.train_c_quizzes.append(new_c_quizzes)
        else:
            self.test_c_quizzes.append(new_c_quizzes)

    def compute_correctness(
        self,
        c_quizzes,
        models_for_validation,
        bidirectional_validation=False,
        deterministic_validation=True,
    ):
        """Count, for each culture quiz, how many models solve it.

        With `bidirectional_validation` a model gets credit for a quiz
        only if it solves both the quiz and its time-reversed version.

        Returns:
            nb_correct: per-quiz number of models that solved it.
            seq_logproba: tensor of size (nb quizzes, max model id + 1)
                with the log-probabilities accumulated during the
                generation, one column per model id.
        """
        if bidirectional_validation:
            # The class defines no forward_to_backward method;
            # reverse_time is the one that flips the quiz direction.
            backward_c_quizzes = self.reverse_time(c_quizzes)

        seq_logproba = torch.zeros(
            c_quizzes.size(0),
            max(m.id for m in models_for_validation) + 1,
            device=self.device,
        )

        nb_correct = 0

        for model in models_for_validation:
            result = c_quizzes.clone()

            ar_mask = self.make_ar_mask(result)

            masked_inplace_autoregression(
                model=model,
                batch_size=self.batch_size,
                input=result,
                ar_mask=ar_mask,
                seq_logproba=seq_logproba[:, model.id],
                temperature=1.0,
                deterministic_synthesis=deterministic_validation,
                # progress_bar_desc="solving c_quizzes",
                device=self.device,
            )

            correct = (c_quizzes == result).long().min(dim=-1).values

            if bidirectional_validation:
                backward_result = backward_c_quizzes.clone()

                ar_mask = self.make_ar_mask(backward_result)

                masked_inplace_autoregression(
                    model=model,
                    batch_size=self.batch_size,
                    input=backward_result,
                    ar_mask=ar_mask,
                    seq_logproba=seq_logproba[:, model.id],
                    temperature=1.0,
                    deterministic_synthesis=deterministic_validation,
                    # progress_bar_desc="solving backward c_quizzes",
                    device=self.device,
                )

                backward_correct = (
                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
                )

                correct *= backward_correct

            # endif

            nb_correct += correct

        return nb_correct, seq_logproba

    ###############################################################

    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
        """Generate `nb` culture quizzes with `model_for_generation` and
        return them in the forward layout.

        The log-probabilities accumulated in `seq_logproba` during the
        three generation passes are not returned.
        """
        # Zero-filled rather than uninitialized so that no arbitrary
        # token value can ever be fed to the model
        c_quizzes = torch.zeros(
            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
        )

        seq_logproba = torch.zeros(nb, device=self.device)

        # First, we generate the answer of a backward quiz at
        # temperature `temperature`

        c_quizzes[:, 0] = self.token_backward
        c_quizzes[:, 1 + self.answer_len] = self.token_backward

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes, first=True),
            seq_logproba=seq_logproba,
            temperature=temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then, we generate the prompt at temperature 1 / `temperature`

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes),
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        # Then we reverse the quiz, and re-generate the answer, again
        # at temperature 1 / `temperature`

        c_quizzes = self.reverse_time(c_quizzes)

        masked_inplace_autoregression(
            model=model_for_generation,
            batch_size=self.batch_size,
            input=c_quizzes,
            ar_mask=self.make_ar_mask(c_quizzes),
            seq_logproba=seq_logproba,
            temperature=1 / temperature,
            deterministic_synthesis=False,
            device=self.device,
        )

        return c_quizzes