projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
quiz_machine.py
diff --git
a/quiz_machine.py
b/quiz_machine.py
index
631d41b
..
eab41dc
100755
(executable)
--- a/
quiz_machine.py
+++ b/
quiz_machine.py
@@
-241,7
+241,7
@@
class QuizMachine:
self.train_c_quizzes = []
self.test_c_quizzes = []
self.train_c_quizzes = []
self.test_c_quizzes = []
- def save_quiz
ze
s(
+ def save_quiz
_illustration
s(
self,
result_dir,
filename_prefix,
self,
result_dir,
filename_prefix,
@@
-266,7
+266,7
@@
class QuizMachine:
predicted_prompts *= 2
predicted_answers *= 2
predicted_prompts *= 2
predicted_answers *= 2
- self.problem.save_quiz
ze
s(
+ self.problem.save_quiz
_illustration
s(
result_dir,
filename_prefix,
quizzes[:, 1 : 1 + self.prompt_len],
result_dir,
filename_prefix,
quizzes[:, 1 : 1 + self.prompt_len],
@@
-373,7
+373,7
@@
class QuizMachine:
return result, correct
return result, correct
- compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+
#
compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
test_result, test_correct = compute_accuracy(
model.test_w_quizzes[:nmax], log_prefix="test"
test_result, test_correct = compute_accuracy(
model.test_w_quizzes[:nmax], log_prefix="test"
@@
-384,7
+384,7
@@
class QuizMachine:
##############################
##############################
- self.save_quiz
ze
s(
+ self.save_quiz
_illustration
s(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
@@
-412,6
+412,12
@@
class QuizMachine:
else:
self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
else:
self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+ def save_c_quizzes(self, filename):
+ torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+ def load_c_quizzes(self, filename):
+ self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
######################################################################
def logproba_of_solutions(self, models, c_quizzes):
######################################################################
def logproba_of_solutions(self, models, c_quizzes):