X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=f799bf1c52a133dd6c2970ae80ac1f835143d45c;hb=e5efa329be244007e11013af84be1f448a04e1a0;hp=d8ebad80f2ceea399e0ee7558b8c14d79f5b8e94;hpb=952dabe800dba1bb7bb295e3600022ea2fba0b66;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index d8ebad8..f799bf1 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -27,7 +27,7 @@ def masked_inplace_autoregression( deterministic_synthesis, forbidden_tokens=None, logit_biases=None, - progress_bar_desc="autoregression", + progress_bar_desc=None, device=torch.device("cpu"), ): assert input.size() == ar_mask.size() @@ -67,40 +67,14 @@ def masked_inplace_autoregression( ###################################################################### -class Task: - def batches(self, split="train", nb_to_use=-1, desc=None): - pass - - def vocabulary_size(self): - pass - - def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis - ): - pass - - -###################################################################### - -import sky - - -class QuizzMachine(Task): - def save_image(self, input, result_dir, filename, logger): - img = sky.seq2img(input.to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - - def save_quizzes(self, input, result_dir, filename_prefix, logger): - self.save_image(input, result_dir, filename_prefix + ".png", logger) - +class QuizzMachine: def make_ar_mask(self, input): b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2 return b.long()[None, :].expand_as(input) def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -110,18 +84,12 @@ class QuizzMachine(Task): ): super().__init__() + self.problem = problem self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 - - self.train_w_quizzes = sky.generate_seq( - nb_train_samples, height=self.height, width=self.width - ).to(device) - self.test_w_quizzes = sky.generate_seq( - nb_test_samples, height=self.height, width=self.width - ).to(device) + self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device) + self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device) self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1 @@ -129,8 +97,8 @@ class QuizzMachine(Task): self.test_c_quizzes = [] if result_dir is not None: - self.save_quizzes( - self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger + self.problem.save_quizzes( + self.train_w_quizzes[:72], result_dir, "culture_w_quizzes" ) def batches(self, split="train", desc=None): @@ -237,11 +205,8 @@ class QuizzMachine(Task): device=self.device, ) - self.save_quizzes( - result[:72], - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - logger, + self.problem.save_quizzes( + result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}" ) return main_test_accuracy @@ -250,9 +215,7 @@ class QuizzMachine(Task): input = self.train_w_quizzes if for_train else self.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() - input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to( - self.device - ) + input[-nb:] = self.problem.generate_seq(nb).to(self.device) def store_c_quizzes(self, new_c_quizzes, for_train=True): if for_train: @@ -262,71 +225,71 @@ class QuizzMachine(Task): def create_c_quizzes( self, + nb, + model_for_generation, + models_for_validation, + min_ave_seq_logproba, n_epoch, result_dir, logger, - nb, - model, - other_models, - min_ave_seq_logproba, ): ############################################################### # Generate quizzes with model c_quizzes = torch.empty( - nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 + nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) ar_mask = torch.full(c_quizzes.size(), 1, device=self.device) seq_logproba = torch.empty(ar_mask.size(0), device=self.device) temperature = 1 - d_temperature = 1 + d_temperature = 1 / 3 while True: seq_logproba[...] = 0 masked_inplace_autoregression( - model=model, + model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, - progress_bar_desc="sampling c_quizzes", + # progress_bar_desc="sampling c_quizzes", device=self.device, ) ave_seq_logproba = seq_logproba.mean() - logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if min_ave_seq_logproba is None: break # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if ave_seq_logproba < min_ave_seq_logproba: if d_temperature > 0: d_temperature *= -1 / 3 temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba: + elif ave_seq_logproba > min_ave_seq_logproba * 0.99: if d_temperature < 0: d_temperature *= -1 / 3 temperature += d_temperature else: break - logger(f"chaging temperature to {temperature}") + logger(f"changing temperature to {temperature}") ############################################################### # Create the reverse quizzes - l = self.height * self.width + token_forward, token_backward = self.problem.direction_tokens() + + l = (c_quizzes.size(1) - 1) // 2 direction = c_quizzes[:, l : l + 1] - direction = sky.token_forward * ( - direction == sky.token_backward - ) + sky.token_backward * (direction == sky.token_forward) + direction = self.problem.token_forward * ( + direction == self.problem.token_backward + ) + self.problem.token_backward * (direction == self.problem.token_forward) reverse_c_quizzes = torch.cat( [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 ) @@ -340,18 +303,18 @@ class QuizzMachine(Task): nb_correct = [] - for m in other_models: + for model in models_for_validation: result = c_quizzes.clone() masked_inplace_autoregression( - model=m, + model=model, batch_size=self.batch_size, input=result, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving c_quizzes", + # progress_bar_desc="solving c_quizzes", device=self.device, ) @@ -360,14 +323,14 @@ class QuizzMachine(Task): reverse_result = reverse_c_quizzes.clone() masked_inplace_autoregression( - model=m, + model=model, batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving reversed c_quizzes", + # progress_bar_desc="solving reversed c_quizzes", device=self.device, )