X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quiz_machine.py;h=bc468d3d4cb3ce7d1aea7ef127773c4d494cdd32;hb=refs%2Fheads%2Fmaster;hp=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hpb=693af34e144cd20d2dde6a508a190d49c1a76c7f;p=culture.git

diff --git a/quiz_machine.py b/quiz_machine.py
index cdfba85..bc468d3 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -5,7 +5,7 @@

 # Written by Francois Fleuret

-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys

 import torch, torchvision

@@ -15,6 +15,38 @@ from torch.nn import functional as F

 import mygpt
 from mygpt import BracketedSequence

+import threading
+
+######################################################################
+# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
+# | X != Y)
+
+
+# output is NxCxT and target is NxT
+def confusion(output, target, reduction="mean"):
+    N, C, T = output.shape
+    output = output.permute(0, 2, 1).reshape(-1, C)
+    target = target.flatten()
+    all_t = torch.arange(N * T, device=output.device)
+    output = output.log_softmax(dim=-1)
+    result = -output[all_t, target]
+
+    output[all_t, target] = float("-inf")
+    output = output.log_softmax(dim=-1)
+    e = output.exp()
+    output[all_t, target] = 0
+    result = result - (output * e).sum(-1)
+
+    if reduction == "none":
+        return result.reshape(N, T)
+    elif reduction == "mean":
+        return result.reshape(N, T).mean()
+    elif reduction == "sum":
+        return result.reshape(N, T).sum()
+    else:
+        raise ValueError(f"unknown reduction '{reduction}'.")
+
+
 ######################################################################

 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -235,23 +267,11 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None

-        # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        # self.reverse_random_half_in_place(self.train_w_quizzes)
-
-        # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        # self.reverse_random_half_in_place(self.test_w_quizzes)
-
+        self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []

-        # if result_dir is not None:
-        #     self.save_quizzes(
-        #         result_dir,
-        #         "culture_w_quizzes",
-        #         self.train_w_quizzes[:72],
-        #     )
-
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -276,7 +296,7 @@ class QuizMachine:
             predicted_prompts *= 2
             predicted_answers *= 2

-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -285,34 +305,41 @@ class QuizMachine:
             predicted_answers,
         )

+    def vocabulary_size(self):
+        return self.nb_token_values
+
+    ######################################################################
+
     def batches(self, model, split="train", desc=None):
         assert split in {"train", "test"}

-        if split == "train":
-            w_quizzes = model.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = model.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
-
-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
-
-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
+        with self.LOCK_C_QUIZZES:
+            if split == "train":
+                w_quizzes = model.train_w_quizzes
+                c_quizzes = self.train_c_quizzes
+            else:
+                w_quizzes = model.test_w_quizzes
+                c_quizzes = self.test_c_quizzes
+
+            if len(c_quizzes) > 0:
+                c_quizzes = torch.cat(c_quizzes, dim=0)
+                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                    c_quizzes = c_quizzes[i]
+
+                i = torch.randperm(w_quizzes.size(0))[
+                    : w_quizzes.size(0) - c_quizzes.size(0)
+                ]
+                w_quizzes = w_quizzes[i]

-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = c_quizzes.size(0)

-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
+                input = torch.cat([w_quizzes, c_quizzes], dim=0)
+            else:
+                input = w_quizzes
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = 0

         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -324,13 +351,13 @@ class QuizMachine:
         ):
             yield batch

-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################

     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -371,16 +398,12 @@ class QuizMachine:
             backward_nb_total = correct[n_backward].size(0)

             self.logger(
-                f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
-            )
-
-            self.logger(
-                f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
             )

             return result, correct

-        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")

         test_result, test_correct = compute_accuracy(
             model.test_w_quizzes[:nmax], log_prefix="test"
@@ -391,7 +414,7 @@ class QuizMachine:

         ##############################

-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -400,36 +423,63 @@ class QuizMachine:

         return main_test_accuracy

+    ######################################################################
+
     def renew_w_quizzes(self, model, nb, for_train=True):
         input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
+
+    ######################################################################

     def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
+        with self.LOCK_C_QUIZZES:
+            if for_train:
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+            else:
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))

-    def logproba_solution(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)

-        for model in models:
-            for input, l in zip(
-                c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
-            ):
-                ar_mask = self.make_ar_mask(input)
-                output = model(mygpt.BracketedSequence(input)).x
-                ce = (
-                    F.cross_entropy(output.transpose(1, 2), input, reduction="none")
-                    * ar_mask
-                )
-                l[:, model.id] = -ce.sum(dim=-1)
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
+    ######################################################################
+
+    def solution_token_logprobas(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0),
+            len(models),
+            c_quizzes.size(1),
+            device=self.device,
+            dtype=torch.float32,
+        )

-        return logproba
+        for model in models:
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                ):
+                    input = input.to(self.device)
+                    ar_mask = self.make_ar_mask(input)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
+                        * ar_mask
+                    )
+
+                model.train(t)
+
+        return logproba.to("cpu")

 ###############################################################

@@ -558,4 +608,4 @@ class QuizMachine:
             device=self.device,
         )

-        return c_quizzes
+        return c_quizzes.to("cpu")
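
Note on the new confusion() helper (this note is not part of the commit): per its comment, for per-token class scores it returns -log P(X=Y) plus the entropy of the predicted distribution restricted to the wrong classes, H(X | X != Y). A minimal sanity-check sketch, assuming quiz_machine.py and its dependencies (torch, mygpt, ...) are importable; with uniform scores over C classes the cross-entropy term is log(C) and the residual-entropy term is log(C-1):

    import math

    import torch

    from quiz_machine import confusion

    N, C, T = 4, 7, 10
    output = torch.zeros(N, C, T)  # uniform scores over the C classes
    target = torch.randint(C, (N, T))

    # expected value for uniform scores: log(C) + log(C-1)
    expected = math.log(C) + math.log(C - 1)
    print(confusion(output, target).item(), expected)  # both ~3.74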
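Note on the ar_mask convention (also not part of the commit): as the "# ar_mask is a tensor with 0s and 1s, of same shape as input, with" comment says, the mask marks with 1s the token positions to be generated or scored. The masked cross-entropy pattern used in solution_token_logprobas() then yields one log-probability per token, zeroed at masked-out positions. A self-contained sketch with made-up shapes (B, T, C are illustrative placeholders, not values from the repo):

    import torch
    from torch.nn import functional as F

    B, T, C = 2, 6, 5
    logits = torch.randn(B, T, C)  # stand-in for the model output
    input = torch.randint(C, (B, T))  # token sequences
    ar_mask = torch.zeros(B, T, dtype=torch.long)
    ar_mask[:, T // 2 :] = 1  # score only the second half of each sequence

    # per-token log P(input); positions with ar_mask == 0 contribute 0
    logprobas = (
        -F.cross_entropy(logits.transpose(1, 2), input, reduction="none") * ar_mask
    )
    print(logprobas.shape)  # torch.Size([2, 6])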