X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quizz_machine.py;h=2cc6cfd59d02fcb417834a34455a1f42b7dfcbd5;hb=a2346746c9b417eaf97aad87ed31dea92c3bb887;hp=28b94d10dbc990430f1a7587b11645da14b36206;hpb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 28b94d1..2cc6cfd 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -66,8 +66,6 @@ def masked_inplace_autoregression( ###################################################################### -import sky - class QuizzMachine: def make_ar_mask(self, input): @@ -76,6 +74,7 @@ class QuizzMachine: def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -85,7 +84,7 @@ class QuizzMachine: ): super().__init__() - self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) + self.problem = problem self.batch_size = batch_size self.device = device @@ -99,7 +98,7 @@ class QuizzMachine: if result_dir is not None: self.problem.save_quizzes( - self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger + self.train_w_quizzes[:72], result_dir, "culture_w_quizzes" ) def batches(self, split="train", desc=None): @@ -207,10 +206,7 @@ class QuizzMachine: ) self.problem.save_quizzes( - result[:72], - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - logger, + result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}" ) return main_test_accuracy @@ -267,17 +263,15 @@ class QuizzMachine: ave_seq_logproba = seq_logproba.mean() - logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if min_ave_seq_logproba is None: break # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if ave_seq_logproba < min_ave_seq_logproba: if d_temperature > 0: d_temperature *= -1 / 3 temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba: + elif ave_seq_logproba > min_ave_seq_logproba * 0.99: if d_temperature < 0: d_temperature *= -1 / 3 temperature += d_temperature