X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=bc468d3d4cb3ce7d1aea7ef127773c4d494cdd32;hb=HEAD;hp=c39bf7adbf2fc30ddb155578ebcd7b29b245ec6f;hpb=2f87c91cf606a068de1450d198660de7e44cd356;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index c39bf7a..92da03d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -29,9 +29,11 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature, - deterministic_synthesis, + deterministic_synthesis=False, ): + if input.size(0) == 0: + return + to_generate = (ar_mask.sum(0) > 0).nonzero() if to_generate.min() > 0: @@ -43,8 +45,6 @@ def one_batch_masked_inplace_autoregression( logits = output[:, s] - logits = (logits / temperature).log_softmax(dim=-1) - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -58,229 +58,99 @@ def one_batch_masked_inplace_autoregression( input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] -def masked_inplace_autoregression( - model, - batch_size, - input, - ar_mask, - seq_logproba, - temperature, - deterministic_synthesis, - forbidden_tokens=None, - logit_biases=None, - progress_bar_desc=None, - device=torch.device("cpu"), -): - assert input.size() == ar_mask.size() - - batches = zip( - input.split(batch_size), - ar_mask.split(batch_size), - seq_logproba.split(batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + batch_size - 1) // batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, seq_logproba in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=deterministic_synthesis, - ) - - model.train(t) - - ###################################################################### class QuizMachine: - def indices_forward_and_backward(self, quizzes): - i_forward = quizzes[:, 0] == self.token_forward - j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward - i_backward = quizzes[:, 0] == self.token_backward - j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward - assert torch.logical_or( - torch.logical_and(i_forward, j_forward), - torch.logical_and(i_backward, j_backward), - ).all() - return i_forward, i_backward - - def non_trivial(self, quizzes): - quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return torch.logical_not( - self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - ) - ) - - def reverse_time(self, quizzes): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - forward_to_backward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len], - quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1], - quizzes[:, 1 : 1 + self.prompt_len], - ], - dim=1, - ) - - forward_to_backward[:, 0] = self.token_backward - forward_to_backward[:, 1 + self.answer_len] = self.token_backward - - backward_to_forward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.answer_len :], - quizzes[:, 1 + self.answer_len : 2 + self.answer_len], - quizzes[:, 1 : 1 + self.answer_len], - ], - dim=1, - ) - - backward_to_forward[:, 0] = self.token_forward - backward_to_forward[:, 1 + self.prompt_len] = self.token_forward - - m = i_forward.long()[:, None] - - return m * forward_to_backward + (1 - m) * backward_to_forward - - def reverse_random_half_in_place(self, quizzes): - i = torch.rand(quizzes.size(0)) < 0.5 - if i.any(): - quizzes[i] = self.reverse_time(quizzes[i]) - - def make_ar_mask(self, quizzes, first=False): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - t = torch.arange(quizzes.size(1), device=quizzes.device) - - if first: - m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long() - m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long() - else: - m_forward = (t >= 2 + self.prompt_len).long() - m_backward = (t >= 2 + self.answer_len).long() - - m = i_forward.long()[:, None] - - return m * m_forward + (1 - m) * m_backward - - def generate_token_sequences(self, nb): - prompts, answers = self.problem.generate_prompts_and_answers(nb) - - if self.prompt_len is None: - self.prompt_len = prompts.size(1) - - if self.answer_len is None: - self.answer_len = answers.size(1) - - assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len - - result = [] - - for prompt, answer in zip(prompts, answers): - a = [ - torch.tensor([self.token_forward]), - prompt, - torch.tensor([self.token_forward]), - answer, - ] - - result.append(torch.cat(a, dim=0)[None, :]) - - return torch.cat(result, dim=0) - def __init__( self, problem, - nb_train_samples, - nb_test_samples, - back_accuracy, batch_size, result_dir, + prompt_noise, logger, device=torch.device("cpu"), ): super().__init__() - v = problem.nb_token_values() - self.token_forward = v - self.token_backward = v + 1 - self.nb_token_values = v + 2 - self.problem = problem - self.back_accuracy = back_accuracy self.batch_size = batch_size self.device = device self.logger = logger self.prompt_len = None self.answer_len = None + self.prompt_noise = prompt_noise + + # struct, mask_generate, mask_noise, mask_loss + self.train_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + ] + + self.test_structures = self.train_structures self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - def save_quiz_illustrations( + def vocabulary_size(self): + return self.problem.nb_token_values + + ###################################################################### + + def autoregression( self, - result_dir, - filename_prefix, - quizzes, - mistakes=None, + model, + input, + ar_mask, + seq_logproba=None, + progress_bar_desc=None, ): - quizzes = quizzes.clone().to("cpu") - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - assert n_forward.size(0) + backward.size(0) == quizzes.size(0) - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - - predicted_prompts = n_backward.long() - predicted_answers = 1 - predicted_prompts - if mistakes is not None: - # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes.to("cpu") - predicted_answers *= mistakes.to("cpu") - else: - # 0/2 ~ not-to-predict / to predict - predicted_prompts *= 2 - predicted_answers *= 2 + assert input.size() == ar_mask.size() - self.problem.save_quiz_illustrations( - result_dir, - filename_prefix, - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - predicted_prompts, - predicted_answers, + if seq_logproba is None: + seq_logproba = torch.empty(input.size(0), device=self.device) + + batches = zip( + input.split(self.batch_size), + ar_mask.split(self.batch_size), + seq_logproba.split(self.batch_size), ) - def vocabulary_size(self): - return self.nb_token_values + if progress_bar_desc is not None: + batches = tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=(input.size(0) + self.batch_size - 1) // self.batch_size, + ) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, ar_mask, seq_logproba in batches: + one_batch_masked_inplace_autoregression( + model=model, + input=input, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + deterministic_synthesis=False, + ) + + model.train(t) ###################################################################### - def batches(self, model, split="train", desc=None): + def data_input(self, model, split="train"): assert split in {"train", "test"} with self.LOCK_C_QUIZZES: @@ -293,6 +163,7 @@ class QuizMachine: if len(c_quizzes) > 0: c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] c_quizzes = c_quizzes[i] @@ -302,106 +173,154 @@ class QuizMachine: ] w_quizzes = w_quizzes[i] - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + from_w = torch.arange( + quizzes.size(0), device=quizzes.device + ) < w_quizzes.size(0) - input = torch.cat([w_quizzes, c_quizzes], dim=0) else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + quizzes = w_quizzes.clone() + from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) + + i = torch.randperm(quizzes.size(0), device=quizzes.device) + quizzes, from_w = quizzes[i], from_w[i] - # Shuffle - input = input[torch.randperm(input.size(0))] + self.randomize_configuations_inplace( + quizzes, structs=[s for s, _, _, _ in self.train_structures] + ) + + quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) + + if self.prompt_noise > 0.0: + for struct, _, mask_noise, mask_loss in self.train_structures: + i = self.problem.indices_select(quizzes=quizzes, struct=struct) + if i.any(): + quizzes[i] = self.problem.inject_noise( + quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise + ) + quiz_mask_loss[i] = self.make_quiz_mask( + quizzes=quizzes[i], struct=struct, mask=mask_loss + ) - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=desc - ): - yield batch + return quizzes, quiz_mask_loss ###################################################################### - def produce_results( - self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 - ): - def compute_accuracy(input, log_prefix=None): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - seq_logproba = torch.empty(input.size(0), device=self.device) + def make_quiz_mask(self, quizzes, struct, mask): + assert struct in [s for s, _, _, _ in self.train_structures] + return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask) - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) + ###################################################################### - correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) + def predict(self, model, quizzes, struct, mask): + ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) + result = quizzes * (1 - ar_mask) - n_forward = input[:, 0] == self.token_forward - n_backward = input[:, 0] == self.token_backward + seq_logproba = torch.empty(quizzes.size(0), device=self.device) - correct[n_forward] = ( - (input[n_forward] == result[n_forward]).long().min(dim=1).values - ) + self.autoregression( + model=model, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + progress_bar_desc="accuracy", + ) - if self.back_accuracy and n_backward.any(): - # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.reverse_time(result[n_backward]) - back_input[:, 2 + self.prompt_len :] = input[ - n_backward, 1 : 1 + self.answer_len - ] - _, correct[n_backward] = compute_accuracy(back_input) + correct = (result == quizzes).min(dim=1).values.long() - if log_prefix is not None: - forward_nb_correct = correct[n_forward].sum() - forward_nb_total = correct[n_forward].size(0) - backward_nb_correct = correct[n_backward].sum() - backward_nb_total = correct[n_backward].size(0) + return result, correct - self.logger( - f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" - ) + ###################################################################### + + def produce_results(self, n_epoch, model, input, result_dir): + input = input.to(self.device) + result = input.new(input.size()) + correct = input.new(input.size(0)) + predicted_parts = input.new(input.size(0), 4) + + nb = 0 - return result, correct + # We consider all the configurations that we train for + for struct, mask_generate, _, _ in self.test_structures: + i = self.problem.indices_select(quizzes=input, struct=struct) + nb += i.long().sum() + result[i], correct[i] = self.predict( + model=model, quizzes=input[i], struct=struct, mask=mask_generate + ) + predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ + None, : + ] + solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 + correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() - compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") + assert nb == input.size(0) - test_result, test_correct = compute_accuracy( - model.test_w_quizzes[:nmax], log_prefix="test" + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + self.logger( + f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" ) - main_test_accuracy = test_correct.sum() / test_correct.size(0) - self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") + main_test_accuracy = nb_correct / nb_total ############################## - self.save_quiz_illustrations( + correct_parts = predicted_parts * correct[:, None] + + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] + + self.problem.save_quizzes_as_image( result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result[:72], - mistakes=test_correct[:72] * 2 - 1, + f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, ) return main_test_accuracy ###################################################################### - def renew_w_quizzes(self, model, nb, for_train=True): - input = model.train_w_quizzes if for_train else model.test_w_quizzes - nb = min(nb, input.size(0)) - input[:-nb] = input[nb:].clone() - fresh_w_quizzes = self.generate_token_sequences(nb) - self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to("cpu") + def randomize_configuations_inplace(self, quizzes, structs): + r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device) + for c in range(len(structs)): + quizzes[r == c] = self.problem.reconfigure( + quizzes[r == c], struct=structs[c] + ) + + ###################################################################### + + def renew_train_w_quizzes(self, model): + if hasattr(model, "hard_w_quizzes"): + hard_w_quizzes = self.problem.reconfigure( + model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B") + ) + self.logger( + f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" + ) + if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): + nb_to_generate = 0 + model.train_w_quizzes[...] = hard_w_quizzes[ + torch.randperm(hard_w_quizzes.size(0))[ + model.train_w_quizzes.size(0) + ] + ] + else: + nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0) + model.train_w_quizzes[...] = torch.cat( + [ + hard_w_quizzes, + self.problem.generate_w_quizzes(nb_to_generate), + ], + dim=0, + ) + else: + nb_to_generate = 0 + model.train_w_quizzes[...] = self.problem.generate_w_quizzes( + model.train_w_quizzes.size(0) + ) ###################################################################### @@ -420,157 +339,92 @@ class QuizMachine: ###################################################################### - def logproba_of_solutions(self, models, c_quizzes): - logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 - ) - - for model in models: - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) - - model.train(t) - - return logproba.to("cpu") - - ############################################################### - - def compute_correctness( + def models_logprobas( self, - c_quizzes, models_for_validation, - bidirectional_validation=False, - deterministic_validation=True, + c_quizzes, + struct, + mask_loss, + mask_noise=None, + device=None, ): - if bidirectional_validation: - backward_c_quizzes = self.forward_to_backward(c_quizzes) + if device is None: + device = self.device + + c_quizzes = self.problem.reconfigure(c_quizzes, struct) seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, - device=self.device, + device=device, ) - nb_correct = 0 - - seq_logproba[...] = 0.0 + # if self.prompt_noise > 0.0 and mask_noise is not None: + # c_quizzes = self.problem.inject_noise( + # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise + # ) for model in models_for_validation: - result = c_quizzes.clone() - - ar_mask = self.make_ar_mask(result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving c_quizzes", - device=self.device, - ) - - correct = (c_quizzes == result).long().min(dim=-1).values - - if bidirectional_validation: - backward_result = backward_c_quizzes.clone() - - ar_mask = self.make_ar_mask(backward_result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=backward_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving backward c_quizzes", - device=self.device, - ) - - backward_correct = ( - (backward_c_quizzes == backward_result).long().min(dim=-1).values - ) - - correct *= backward_correct - - # endif + with torch.autograd.no_grad(): + t = model.training + model.eval() - nb_correct += correct + for input, l in zip( + c_quizzes.split(self.batch_size), + seq_logproba.split(self.batch_size), + ): + input = input.to(device) + quiz_mask_loss = self.make_quiz_mask( + input, struct=struct, mask=mask_loss + ) + output = model(mygpt.BracketedSequence(input)).x + l[:, model.id] = ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + * quiz_mask_loss + ).sum(dim=1) - return nb_correct, seq_logproba + model.train(t) - ############################################################### + return seq_logproba.to("cpu") - def generate_quizzes(self, nb, model_for_generation, temperature=1.0): - c_quizzes = torch.empty( - nb, - self.prompt_len + self.answer_len + 2, - device=self.device, - dtype=torch.int64, - ) + ###################################################################### + def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): seq_logproba = torch.zeros(nb, device=self.device) - # First, we generate the answer at high temperature - - c_quizzes[:, 0] = self.token_backward - c_quizzes[:, 1 + self.answer_len] = self.token_backward + c_quizzes = None - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=False, - device=self.device, - ) + for s, m, mt in procedure: + if c_quizzes is None: + c_quizzes = self.problem.create_empty_quizzes(nb, s) + c_quizzes = c_quizzes.to(self.device) + elif s != pred_s: + c_quizzes = self.problem.reconfigure(c_quizzes, s) + pred_s = s - # Then, we generate the prompt at low temperature + if mt is not None: + mt(model_for_generation) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) + self.autoregression( + model=model_for_generation, + input=c_quizzes, + ar_mask=self.make_quiz_mask(c_quizzes, s, m), + seq_logproba=seq_logproba, + ) - # Then we return the quizz, and re-generate the response, now - # at low temperature + model_for_generation.reset_transformations() - c_quizzes = self.reverse_time(c_quizzes) + if recorder is not None: + x = c_quizzes.clone() + t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1) + recorder.append( + self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) + ) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) return c_quizzes.to("cpu") + + ######################################################################