X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=ae146147ca1f753f327943b1494811081d261be1;hb=4a63b2b44bc08cb04b236b35a3d36aa242912d48;hp=c1477c9bb497498ddbe5aa4b2cc898ba1a915796;hpb=f7a8cd141a39039048d3e9311220a33079f2cfc7;p=culture.git

diff --git a/quiz_machine.py b/quiz_machine.py
index c1477c9..ae14614 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -15,6 +15,8 @@ from torch.nn import functional as F
 import mygpt
 from mygpt import BracketedSequence

+import threading
+
 ######################################################################

 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -235,24 +237,10 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None

-        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        self.reverse_random_half_in_place(self.train_w_quizzes)
-        self.train_w_quizzes = self.train_w_quizzes.to(device)
-
-        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        self.reverse_random_half_in_place(self.test_w_quizzes)
-        self.test_w_quizzes = self.test_w_quizzes.to(device)
-
+        self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []

-        if result_dir is not None:
-            self.save_quizzes(
-                result_dir,
-                "culture_w_quizzes",
-                self.train_w_quizzes[:72],
-            )
-
     def save_quizzes(
         self,
         result_dir,
@@ -287,34 +275,41 @@
             predicted_answers,
         )

-    def batches(self, split="train", desc=None):
-        assert split in {"train", "test"}
-        if split == "train":
-            w_quizzes = self.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = self.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
+    def vocabulary_size(self):
+        return self.nb_token_values

-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
+    ######################################################################

-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
+    def batches(self, model, split="train", desc=None):
+        assert split in {"train", "test"}

-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
+        with self.LOCK_C_QUIZZES:
+            if split == "train":
+                w_quizzes = model.train_w_quizzes
+                c_quizzes = self.train_c_quizzes
+            else:
+                w_quizzes = model.test_w_quizzes
+                c_quizzes = self.test_c_quizzes
+
+            if len(c_quizzes) > 0:
+                c_quizzes = torch.cat(c_quizzes, dim=0)
+                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                    c_quizzes = c_quizzes[i]
+
+                i = torch.randperm(w_quizzes.size(0))[
+                    : w_quizzes.size(0) - c_quizzes.size(0)
+                ]
+                w_quizzes = w_quizzes[i]

-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = c_quizzes.size(0)
+
+                input = torch.cat([w_quizzes, c_quizzes], dim=0)
+            else:
+                input = w_quizzes
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = 0

         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -326,13 +321,13 @@
         ):
             yield batch

-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################

     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -382,10 +377,10 @@

             return result, correct

-        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+        compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")

         test_result, test_correct = compute_accuracy(
-            self.test_w_quizzes[:nmax], log_prefix="test"
+            model.test_w_quizzes[:nmax], log_prefix="test"
         )

         main_test_accuracy = test_correct.sum() / test_correct.size(0)
@@ -402,36 +397,46 @@

         return main_test_accuracy

-    def renew_w_quizzes(self, nb, for_train=True):
-        input = self.train_w_quizzes if for_train else self.test_w_quizzes
+    ######################################################################
+
+    def renew_w_quizzes(self, model, nb, for_train=True):
+        input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
+
+    ######################################################################

     def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
+        with self.LOCK_C_QUIZZES:
+            if for_train:
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+            else:
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))

-    def logproba_solution(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+    ######################################################################
+
+    def logproba_of_solutions(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0), len(models), device=self.device
+        )

         for model in models:
             for input, l in zip(
                 c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
             ):
+                input = input.to(self.device)
                 ar_mask = self.make_ar_mask(input)
                 output = model(mygpt.BracketedSequence(input)).x
                 ce = (
                     F.cross_entropy(output.transpose(1, 2), input, reduction="none")
                     * ar_mask
                 )
-                l[:, model.id] = ce.sum(dim=-1)
+                l[:, model.id] = -ce.sum(dim=-1)

-        return logproba
+        return logproba.to("cpu")

     ###############################################################

@@ -507,7 +512,10 @@

     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
-            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+            nb,
+            self.prompt_len + self.answer_len + 2,
+            device=self.device,
+            dtype=torch.int64,
         )

         seq_logproba = torch.zeros(nb, device=self.device)
@@ -557,4 +565,4 @@
             device=self.device,
         )

-        return c_quizzes
+        return c_quizzes.to("cpu")
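
The patch above removes the QuizMachine-owned w_quizzes tensors (each model now carries its own train_w_quizzes / test_w_quizzes on CPU) and keeps only the shared c_quizzes lists inside the QuizMachine, guarded by the new LOCK_C_QUIZZES so that batches() and store_c_quizzes() can be called from several training threads at once. The snippet below is a minimal, self-contained sketch of that locking pattern only; SharedStore and worker are made-up names for illustration, not code from the repository.

import threading

class SharedStore:
    def __init__(self):
        self.lock = threading.Lock()  # plays the role of LOCK_C_QUIZZES
        self.c_quizzes = []           # shared list, mutated by several threads

    def store(self, batch):
        with self.lock:               # writers serialize on the lock
            self.c_quizzes.append(batch)

    def snapshot(self):
        with self.lock:               # readers copy under the same lock
            return list(self.c_quizzes)

def worker(store, k):
    for i in range(100):
        store.store((k, i))

store = SharedStore()
threads = [threading.Thread(target=worker, args=(store, k)) for k in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(store.snapshot()))  # 400: no append is lost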
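
A second behavioural detail: logproba_solution is renamed logproba_of_solutions and now stores -ce.sum(dim=-1), so the returned CPU tensor holds, per model, the log-probability of the masked answer tokens (higher means the model finds the solution more likely) rather than a summed cross-entropy loss. The toy check below illustrates that identity; the shapes and tensors are invented for the example and are not the project's.

import torch
from torch.nn import functional as F

logits = torch.randn(2, 5, 7)          # batch 2, length 5, 7 token values
tokens = torch.randint(7, (2, 5))
mask = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 1, 1, 1, 1]])  # 1 marks the positions to score

ce = F.cross_entropy(logits.transpose(1, 2), tokens, reduction="none") * mask
logproba = -ce.sum(dim=-1)              # sum of log p(token) over masked positions

# Same quantity computed directly from the log-softmax:
logp = logits.log_softmax(dim=-1).gather(-1, tokens[..., None]).squeeze(-1)
assert torch.allclose(logproba, (logp * mask).sum(dim=-1), atol=1e-5)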