######################################################################
-import sky
-
class QuizzMachine:
def make_ar_mask(self, input):
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
):
super().__init__()
- self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+ self.problem = problem
self.batch_size = batch_size
self.device = device
ave_seq_logproba = seq_logproba.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if ave_seq_logproba < min_ave_seq_logproba:
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature