projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
quizz_machine.py
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
28b94d1
..
2cc6cfd
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-66,8
+66,6
@@
def masked_inplace_autoregression(
######################################################################
######################################################################
-import sky
-
class QuizzMachine:
def make_ar_mask(self, input):
class QuizzMachine:
def make_ar_mask(self, input):
@@
-76,6
+74,7
@@
class QuizzMachine:
def __init__(
self,
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
nb_train_samples,
nb_test_samples,
batch_size,
@@
-85,7
+84,7
@@
class QuizzMachine:
):
super().__init__()
):
super().__init__()
- self.problem =
sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+ self.problem =
problem
self.batch_size = batch_size
self.device = device
self.batch_size = batch_size
self.device = device
@@
-99,7
+98,7
@@
class QuizzMachine:
if result_dir is not None:
self.problem.save_quizzes(
if result_dir is not None:
self.problem.save_quizzes(
- self.train_w_quizzes[:72], result_dir,
f"culture_w_quizzes", logger
+ self.train_w_quizzes[:72], result_dir,
"culture_w_quizzes"
)
def batches(self, split="train", desc=None):
)
def batches(self, split="train", desc=None):
@@
-207,10
+206,7
@@
class QuizzMachine:
)
self.problem.save_quizzes(
)
self.problem.save_quizzes(
- result[:72],
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- logger,
+ result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
)
return main_test_accuracy
)
return main_test_accuracy
@@
-267,17
+263,15
@@
class QuizzMachine:
ave_seq_logproba = seq_logproba.mean()
ave_seq_logproba = seq_logproba.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba
* 1.1
:
+ if ave_seq_logproba < min_ave_seq_logproba:
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif ave_seq_logproba > min_ave_seq_logproba
* 0.99
:
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature