projects
/
culture.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
0c13a50
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 6 Jul 2024 04:37:52 +0000
(07:37 +0300)
committer
François Fleuret
<francois@fleuret.org>
Sat, 6 Jul 2024 04:37:52 +0000
(07:37 +0300)
quizz_machine.py
patch
|
blob
|
history
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
c6c2f95
..
92b5799
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-238,7
+238,7
@@
class QuizzMachine:
result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
result_dir,
"culture_w_quizzes",
self.train_w_quizzes[:72],
-
show_to_be_predicted=True
,
+
n_backward=self.train_w_quizzes[:72, 0] == self.token_backward
,
)
def save_quizzes(
)
def save_quizzes(
@@
-246,7
+246,7
@@
class QuizzMachine:
result_dir,
filename_prefix,
quizzes,
result_dir,
filename_prefix,
quizzes,
-
show_to_be_predicted=Fals
e,
+
n_backward=Non
e,
mistakes=None,
):
quizzes = quizzes.clone()
mistakes=None,
):
quizzes = quizzes.clone()
@@
-256,8
+256,11
@@
class QuizzMachine:
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
assert forward.size(0) + backward.size(0) == quizzes.size(0)
quizzes[ib] = self.reverse_time(quizzes[ib])
- if show_to_be_predicted:
- predicted_prompts = ib.long()
+ if n_backward is None:
+ predicted_prompts = None
+ predicted_answers = None
+ else:
+ predicted_prompts = n_backward.long()
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
@@
-267,9
+270,6
@@
class QuizzMachine:
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- else:
- predicted_prompts = None
- predicted_answers = None
self.problem.save_quizzes(
result_dir,
self.problem.save_quizzes(
result_dir,
@@
-390,7
+390,7
@@
class QuizzMachine:
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
-
show_to_be_predicted=True
,
+
n_backward=self.test_w_quizzes[:72, 0] == self.token_backward
,
mistakes=test_correct[:72] * 2 - 1,
)
mistakes=test_correct[:72] * 2 - 1,
)