X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=quizz_machine.py;h=f799bf1c52a133dd6c2970ae80ac1f835143d45c;hb=2186d96fccfc525884f1b3fb722c40642891ab0a;hp=28b94d10dbc990430f1a7587b11645da14b36206;hpb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 28b94d1..f799bf1 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -27,7 +27,7 @@ def masked_inplace_autoregression( deterministic_synthesis, forbidden_tokens=None, logit_biases=None, - progress_bar_desc="autoregression", + progress_bar_desc=None, device=torch.device("cpu"), ): assert input.size() == ar_mask.size() @@ -66,8 +66,6 @@ def masked_inplace_autoregression( ###################################################################### -import sky - class QuizzMachine: def make_ar_mask(self, input): @@ -76,6 +74,7 @@ class QuizzMachine: def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -85,7 +84,7 @@ class QuizzMachine: ): super().__init__() - self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) + self.problem = problem self.batch_size = batch_size self.device = device @@ -99,7 +98,7 @@ class QuizzMachine: if result_dir is not None: self.problem.save_quizzes( - self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger + self.train_w_quizzes[:72], result_dir, "culture_w_quizzes" ) def batches(self, split="train", desc=None): @@ -207,10 +206,7 @@ class QuizzMachine: ) self.problem.save_quizzes( - result[:72], - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - logger, + result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}" ) return main_test_accuracy @@ -229,13 +225,13 @@ class QuizzMachine: def create_c_quizzes( self, + nb, + model_for_generation, + models_for_validation, + min_ave_seq_logproba, n_epoch, result_dir, logger, - nb, - model, - other_models, - min_ave_seq_logproba, ): ############################################################### # Generate quizzes with model @@ -254,37 +250,35 @@ class QuizzMachine: seq_logproba[...] = 0 masked_inplace_autoregression( - model=model, + model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, - progress_bar_desc="sampling c_quizzes", + # progress_bar_desc="sampling c_quizzes", device=self.device, ) ave_seq_logproba = seq_logproba.mean() - logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if min_ave_seq_logproba is None: break # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if ave_seq_logproba < min_ave_seq_logproba: if d_temperature > 0: d_temperature *= -1 / 3 temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba: + elif ave_seq_logproba > min_ave_seq_logproba * 0.99: if d_temperature < 0: d_temperature *= -1 / 3 temperature += d_temperature else: break - logger(f"chaging temperature to {temperature}") + logger(f"changing temperature to {temperature}") ############################################################### # Create the reverse quizzes @@ -309,18 +303,18 @@ class QuizzMachine: nb_correct = [] - for m in other_models: + for model in models_for_validation: result = c_quizzes.clone() masked_inplace_autoregression( - model=m, + model=model, batch_size=self.batch_size, input=result, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving c_quizzes", + # progress_bar_desc="solving c_quizzes", device=self.device, ) @@ -329,14 +323,14 @@ class QuizzMachine: reverse_result = reverse_c_quizzes.clone() masked_inplace_autoregression( - model=m, + model=model, batch_size=self.batch_size, input=reverse_result, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=1.0, deterministic_synthesis=True, - progress_bar_desc="solving reversed c_quizzes", + # progress_bar_desc="solving reversed c_quizzes", device=self.device, )