X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=1f1046dd8120a6568c811795df92a97ad38e5c87;hb=050976a525fee2d3b824350a3058ab7299a2bd3d;hp=c1477c9bb497498ddbe5aa4b2cc898ba1a915796;hpb=f7a8cd141a39039048d3e9311220a33079f2cfc7;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index c1477c9..1f1046d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -15,6 +15,8 @@ from torch.nn import functional as F import mygpt from mygpt import BracketedSequence +import threading + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -235,24 +237,10 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) - - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) - + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) - def save_quizzes( self, result_dir, @@ -287,34 +275,41 @@ class QuizMachine: predicted_answers, ) - def batches(self, split="train", desc=None): - assert split in {"train", "test"} - if split == "train": - w_quizzes = self.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = self.test_w_quizzes - c_quizzes = self.test_c_quizzes + def vocabulary_size(self): + return self.nb_token_values - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] + ###################################################################### - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] + def batches(self, model, split="train", desc=None): + assert split in {"train", "test"} - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + with self.LOCK_C_QUIZZES: + if split == "train": + w_quizzes = model.train_w_quizzes + c_quizzes = self.train_c_quizzes + else: + w_quizzes = model.test_w_quizzes + c_quizzes = self.test_c_quizzes + + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] + + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = c_quizzes.size(0) + + input = torch.cat([w_quizzes, c_quizzes], dim=0) + else: + input = w_quizzes + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = 0 # Shuffle input = input[torch.randperm(input.size(0))] @@ -326,8 +321,7 @@ class QuizMachine: ): yield batch - def vocabulary_size(self): - return self.nb_token_values + ###################################################################### def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 @@ -382,10 +376,10 @@ class QuizMachine: return result, correct - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" + model.test_w_quizzes[:nmax], log_prefix="test" ) main_test_accuracy = test_correct.sum() / test_correct.size(0) @@ -402,21 +396,28 @@ class QuizMachine: return main_test_accuracy - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes + ###################################################################### + + def renew_w_quizzes(self, model, nb, for_train=True): + input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) self.reverse_random_half_in_place(fresh_w_quizzes) input[-nb:] = fresh_w_quizzes.to(self.device) + ###################################################################### + def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) + with self.LOCK_C_QUIZZES: + if for_train: + self.train_c_quizzes.append(new_c_quizzes) + else: + self.test_c_quizzes.append(new_c_quizzes) - def logproba_solution(self, models, c_quizzes): + ###################################################################### + + def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)) for model in models: @@ -429,7 +430,7 @@ class QuizMachine: F.cross_entropy(output.transpose(1, 2), input, reduction="none") * ar_mask ) - l[:, model.id] = ce.sum(dim=-1) + l[:, model.id] = -ce.sum(dim=-1) return logproba @@ -507,7 +508,10 @@ class QuizMachine: def generate_quizzes(self, nb, model_for_generation, temperature=1.0): c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 + nb, + self.prompt_len + self.answer_len + 2, + device=self.device, + dtype=torch.int64, ) seq_logproba = torch.zeros(nb, device=self.device)