From c1f6e8f7b71e0dad6466a3466526da6bcf5f201a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Sep 2024 16:28:21 +0200 Subject: [PATCH] Update. --- grids.py | 6 +- quiz_machine.py | 430 ------------------------------------------------ 2 files changed, 3 insertions(+), 433 deletions(-) delete mode 100755 quiz_machine.py diff --git a/grids.py b/grids.py index ac25781..0f7e554 100755 --- a/grids.py +++ b/grids.py @@ -287,9 +287,9 @@ class Grids(problem.Problem): ###################################################################### def vocabulary_size(self): - warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) - return self.nb_colors + 4 - # return self.nb_colors + # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) + # return self.nb_colors + 4 + return self.nb_colors def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() diff --git a/quiz_machine.py b/quiz_machine.py deleted file mode 100755 index 5bab1e5..0000000 --- a/quiz_machine.py +++ /dev/null @@ -1,430 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, os, tqdm, warnings, sys - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -import mygpt -from mygpt import BracketedSequence - -import threading - -###################################################################### - -# ar_mask is a tensor with 0s and 1s, of same shape as input, with -# 1s where tokens should be generated. The others are kept -# unchanged. - - -def one_batch_masked_inplace_autoregression( - model, - input, - ar_mask, - seq_logproba, - deterministic_synthesis=False, -): - if input.size(0) == 0: - return - - to_generate = (ar_mask.sum(0) > 0).nonzero() - - if to_generate.min() > 0: - model( - BracketedSequence(input, 0, to_generate.min()) - ) # Needed to initialize the model's cache - for s in range(to_generate.min(), to_generate.max() + 1): - output = model(BracketedSequence(input, s, 1)).x - - logits = output[:, s] - - if deterministic_synthesis: - t_next = logits.argmax(-1) - else: - dist = torch.distributions.categorical.Categorical(logits=logits) - t_next = dist.sample() - - all_n = torch.arange(t_next.size(0)) - - seq_logproba += logits.log_softmax(dim=1)[all_n, t_next] - - input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] - - -###################################################################### - - -class QuizMachine: - def __init__( - self, - problem, - batch_size, - result_dir, - prompt_noise, - logger, - device=torch.device("cpu"), - ): - super().__init__() - - self.problem = problem - self.batch_size = batch_size - self.device = device - self.logger = logger - self.prompt_len = None - self.answer_len = None - self.prompt_noise = prompt_noise - - # struct, mask_generate, mask_noise, mask_loss - self.train_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - ] - - self.test_structures = self.train_structures - - self.LOCK_C_QUIZZES = threading.Lock() - self.train_c_quizzes = [] - self.test_c_quizzes = [] - - def vocabulary_size(self): - return self.problem.nb_token_values - - ###################################################################### - - def autoregression( - self, - model, - input, - ar_mask, - seq_logproba=None, - progress_bar_desc=None, - ): - assert input.size() == ar_mask.size() - - if seq_logproba is None: - seq_logproba = torch.empty(input.size(0), device=self.device) - - batches = zip( - input.split(self.batch_size), - ar_mask.split(self.batch_size), - seq_logproba.split(self.batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + self.batch_size - 1) // self.batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, seq_logproba in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - deterministic_synthesis=False, - ) - - model.train(t) - - ###################################################################### - - def data_input(self, model, split="train"): - assert split in {"train", "test"} - - with self.LOCK_C_QUIZZES: - if split == "train": - w_quizzes = model.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = model.test_w_quizzes - c_quizzes = self.test_c_quizzes - - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] - - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] - - quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) - from_w = torch.arange( - quizzes.size(0), device=quizzes.device - ) < w_quizzes.size(0) - - else: - quizzes = w_quizzes.clone() - from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) - - i = torch.randperm(quizzes.size(0), device=quizzes.device) - quizzes, from_w = quizzes[i], from_w[i] - - self.randomize_configuations_inplace( - quizzes, structs=[s for s, _, _, _ in self.train_structures] - ) - - quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - - if self.prompt_noise > 0.0: - for struct, _, mask_noise, mask_loss in self.train_structures: - i = self.problem.indices_select(quizzes=quizzes, struct=struct) - if i.any(): - quizzes[i] = self.problem.inject_noise( - quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise - ) - quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], struct=struct, mask=mask_loss - ) - - return quizzes, quiz_mask_loss - - ###################################################################### - - def make_quiz_mask(self, quizzes, struct, mask): - assert struct in [s for s, _, _, _ in self.train_structures] - return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask) - - ###################################################################### - - def predict(self, model, quizzes, struct, mask): - ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) - result = quizzes * (1 - ar_mask) - - seq_logproba = torch.empty(quizzes.size(0), device=self.device) - - self.autoregression( - model=model, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - progress_bar_desc="accuracy", - ) - - correct = (result == quizzes).min(dim=1).values.long() - - return result, correct - - ###################################################################### - - def produce_results(self, n_epoch, model, input, result_dir): - input = input.to(self.device) - result = input.new(input.size()) - correct = input.new(input.size(0)) - predicted_parts = input.new(input.size(0), 4) - - nb = 0 - - # We consider all the configurations that we train for - for struct, mask_generate, _, _ in self.test_structures: - i = self.problem.indices_select(quizzes=input, struct=struct) - nb += i.long().sum() - result[i], correct[i] = self.predict( - model=model, quizzes=input[i], struct=struct, mask=mask_generate - ) - predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ - None, : - ] - solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 - correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() - - assert nb == input.size(0) - - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() - self.logger( - f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" - ) - - main_test_accuracy = nb_correct / nb_total - - ############################## - - correct_parts = predicted_parts * correct[:, None] - - result = result[:128] - predicted_parts = predicted_parts[:128] - correct_parts = correct_parts[:128] - - self.problem.save_quizzes_as_image( - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) - - return main_test_accuracy - - ###################################################################### - - def randomize_configuations_inplace(self, quizzes, structs): - r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device) - for c in range(len(structs)): - quizzes[r == c] = self.problem.reconfigure( - quizzes[r == c], struct=structs[c] - ) - - ###################################################################### - - def renew_train_w_quizzes(self, model): - if hasattr(model, "hard_w_quizzes"): - hard_w_quizzes = self.problem.reconfigure( - model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B") - ) - self.logger( - f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" - ) - if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): - nb_to_generate = 0 - model.train_w_quizzes[...] = hard_w_quizzes[ - torch.randperm(hard_w_quizzes.size(0))[ - model.train_w_quizzes.size(0) - ] - ] - else: - nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0) - model.train_w_quizzes[...] = torch.cat( - [ - hard_w_quizzes, - self.problem.generate_w_quizzes(nb_to_generate), - ], - dim=0, - ) - else: - nb_to_generate = 0 - model.train_w_quizzes[...] = self.problem.generate_w_quizzes( - model.train_w_quizzes.size(0) - ) - - ###################################################################### - - def store_c_quizzes(self, new_c_quizzes, for_train=True): - with self.LOCK_C_QUIZZES: - if for_train: - self.train_c_quizzes.append(new_c_quizzes.to("cpu")) - else: - self.test_c_quizzes.append(new_c_quizzes.to("cpu")) - - def save_c_quizzes(self, filename): - torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) - - def load_c_quizzes(self, filename): - self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) - - ###################################################################### - - def models_logprobas( - self, - models_for_validation, - c_quizzes, - struct, - mask_loss, - mask_noise=None, - device=None, - ): - if device is None: - device = self.device - - c_quizzes = self.problem.reconfigure(c_quizzes, struct) - - seq_logproba = torch.zeros( - c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, - device=device, - ) - - # if self.prompt_noise > 0.0 and mask_noise is not None: - # c_quizzes = self.problem.inject_noise( - # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise - # ) - - for model in models_for_validation: - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), - seq_logproba.split(self.batch_size), - ): - input = input.to(device) - quiz_mask_loss = self.make_quiz_mask( - input, struct=struct, mask=mask_loss - ) - output = model(mygpt.BracketedSequence(input)).x - l[:, model.id] = ( - -F.cross_entropy( - output.transpose(1, 2), input, reduction="none" - ) - * quiz_mask_loss - ).sum(dim=1) - - model.train(t) - - return seq_logproba.to("cpu") - - ###################################################################### - - def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): - seq_logproba = torch.zeros(nb, device=self.device) - - c_quizzes = None - - for s, m, mt in procedure: - if c_quizzes is None: - c_quizzes = self.problem.create_empty_quizzes(nb, s) - c_quizzes = c_quizzes.to(self.device) - elif s != pred_s: - c_quizzes = self.problem.reconfigure(c_quizzes, s) - pred_s = s - - if mt is not None: - mt(model_for_generation) - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_quiz_mask(c_quizzes, s, m), - seq_logproba=seq_logproba, - ) - - model_for_generation.reset_transformations() - - if recorder is not None: - x = c_quizzes.clone() - t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1) - recorder.append( - self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) - ) - - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - - return c_quizzes.to("cpu") - - ###################################################################### -- 2.39.5