X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=quizz_machine.py;h=d63855c486673eb86211bad50781397e61a9ffd0;hb=e2c3b8046c3fddef8aacb74cf5f848d42044897e;hp=28b94d10dbc990430f1a7587b11645da14b36206;hpb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;p=culture.git diff --git a/quizz_machine.py b/quizz_machine.py index 28b94d1..d63855c 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -66,8 +66,6 @@ def masked_inplace_autoregression( ###################################################################### -import sky - class QuizzMachine: def make_ar_mask(self, input): @@ -76,6 +74,7 @@ class QuizzMachine: def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -85,7 +84,7 @@ class QuizzMachine: ): super().__init__() - self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2) + self.problem = problem self.batch_size = batch_size self.device = device @@ -267,17 +266,15 @@ class QuizzMachine: ave_seq_logproba = seq_logproba.mean() - logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}") - if min_ave_seq_logproba is None: break # Oh man that's ugly - if ave_seq_logproba < min_ave_seq_logproba * 1.1: + if ave_seq_logproba < min_ave_seq_logproba: if d_temperature > 0: d_temperature *= -1 / 3 temperature += d_temperature - elif ave_seq_logproba > min_ave_seq_logproba: + elif ave_seq_logproba > min_ave_seq_logproba * 0.99: if d_temperature < 0: d_temperature *= -1 / 3 temperature += d_temperature