projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
quizz_machine.py
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
717e8ac
..
92b5799
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-238,7
+238,7
@@
class QuizzMachine:
result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
-
show_to_be_predicted=True
,
+
n_backward=self.train_w_quizzes[:72, 0] == self.token_backward
,
)
def save_quizzes(
)
def save_quizzes(
@@
-246,7
+246,7
@@
class QuizzMachine:
result_dir,
filename_prefix,
quizzes,
result_dir,
filename_prefix,
quizzes,
-
show_to_be_predicted=Fals
e,
+
n_backward=Non
e,
mistakes=None,
):
quizzes = quizzes.clone()
mistakes=None,
):
quizzes = quizzes.clone()
@@
-256,8
+256,11
@@
class QuizzMachine:
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if show_to_be_predicted:
- predicted_prompts = ib.long()
+ if n_backward is None:
+ predicted_prompts = None
+ predicted_answers = None
+ else:
+ predicted_prompts = n_backward.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@
-267,9
+270,6
@@
class QuizzMachine:
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- else:
- predicted_prompts = None
- predicted_answers = None
self.problem.save_quizzes(
result_dir,
self.problem.save_quizzes(
result_dir,
@@
-360,28
+360,28
@@
class QuizzMachine:
result[n_backward], correct[n_backward] = compute_accuracy(back_input)
if log_prefix is not None:
result[n_backward], correct[n_backward] = compute_accuracy(back_input)
if log_prefix is not None:
- nb_correct = correct[n_forward].sum()
- nb_total = correct[n_forward].size(0)
- back_nb_correct = correct[n_backward].sum()
- back_nb_total = correct[n_backward].size(0)
+
forward_
nb_correct = correct[n_forward].sum()
+
forward_
nb_total = correct[n_forward].size(0)
+ back
ward
_nb_correct = correct[n_backward].sum()
+ back
ward
_nb_total = correct[n_backward].size(0)
self.logger(
self.logger(
- f"
accuracy {log_prefix} {n_epoch} {model.id=} {nb_correct} / {
nb_total}"
+ f"
forward_accuracy {log_prefix} {n_epoch} {model.id=} {forward_nb_correct} / {forward_
nb_total}"
)
self.logger(
)
self.logger(
- f"back
_accuracy {log_prefix} {n_epoch} {model.id=} {back_nb_correct} / {back
_nb_total}"
+ f"back
ward_accuracy {log_prefix} {n_epoch} {model.id=} {backward_nb_correct} / {backward
_nb_total}"
)
return result, correct
compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
)
return result, correct
compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
-
result,
correct = compute_accuracy(
+
test_result, test_
correct = compute_accuracy(
self.test_w_quizzes[:nmax], log_prefix="test"
)
self.test_w_quizzes[:nmax], log_prefix="test"
)
- main_test_accuracy =
correct.sum() /
correct.size(0)
+ main_test_accuracy =
test_correct.sum() / test_
correct.size(0)
self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
##############################
self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
##############################
@@
-389,9
+389,9
@@
class QuizzMachine:
self.save_quizzes(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
self.save_quizzes(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- quizzes=result[:72],
-
show_to_be_predicted=True
,
- mistakes=correct[:72] * 2 - 1,
+ quizzes=
test_
result[:72],
+
n_backward=self.test_w_quizzes[:72, 0] == self.token_backward
,
+ mistakes=
test_
correct[:72] * 2 - 1,
)
return main_test_accuracy
)
return main_test_accuracy