X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=631d41bcb46939277a36cab2bdf7fb95aeac8d24;hb=3ba9d1e0d85d689c2bdea9d2d571c6e8851a55b5;hp=cdfba85e6a2abb8d6cd1b14bb3b0f76fe2afad30;hpb=693af34e144cd20d2dde6a508a190d49c1a76c7f;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index cdfba85..631d41b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -15,6 +15,8 @@ from torch.nn import functional as F import mygpt from mygpt import BracketedSequence +import threading + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -235,22 +237,10 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - # self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - # self.reverse_random_half_in_place(self.train_w_quizzes) - - # self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - # self.reverse_random_half_in_place(self.test_w_quizzes) - + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - # if result_dir is not None: - # self.save_quizzes( - # result_dir, - # "culture_w_quizzes", - # self.train_w_quizzes[:72], - # ) - def save_quizzes( self, result_dir, @@ -285,34 +275,41 @@ class QuizMachine: predicted_answers, ) + def vocabulary_size(self): + return self.nb_token_values + + ###################################################################### + def batches(self, model, split="train", desc=None): assert split in {"train", "test"} - if split == "train": - w_quizzes = model.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = model.test_w_quizzes - c_quizzes = self.test_c_quizzes - - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] + with self.LOCK_C_QUIZZES: + if split == "train": + w_quizzes = model.train_w_quizzes + c_quizzes = self.train_c_quizzes + else: + w_quizzes = model.test_w_quizzes + c_quizzes = self.test_c_quizzes + + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] + + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = c_quizzes.size(0) - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + input = torch.cat([w_quizzes, c_quizzes], dim=0) + else: + input = w_quizzes + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = 0 # Shuffle input = input[torch.randperm(input.size(0))] @@ -324,13 +321,13 @@ class QuizMachine: ): yield batch - def vocabulary_size(self): - return self.nb_token_values + ###################################################################### def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 ): def compute_accuracy(input, log_prefix=None): + input = input.to(self.device) ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -371,11 +368,7 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" + f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" ) return result, correct @@ -400,36 +393,52 @@ class QuizMachine: return main_test_accuracy + ###################################################################### + def renew_w_quizzes(self, model, nb, for_train=True): input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to(self.device) + input[-nb:] = fresh_w_quizzes.to("cpu") + + ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) + with self.LOCK_C_QUIZZES: + if for_train: + self.train_c_quizzes.append(new_c_quizzes.to("cpu")) + else: + self.test_c_quizzes.append(new_c_quizzes.to("cpu")) - def logproba_solution(self, models, c_quizzes): - logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) + ###################################################################### - for model in models: - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) + def logproba_of_solutions(self, models, c_quizzes): + logproba = c_quizzes.new_zeros( + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 + ) - return logproba + for model in models: + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + model.train(t) + + return logproba.to("cpu") ############################################################### @@ -558,4 +567,4 @@ class QuizMachine: device=self.device, ) - return c_quizzes + return c_quizzes.to("cpu")