projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
quizz_machine.py
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
239dc68
..
bf36d0b
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-311,7
+311,6
@@
class QuizzMachine:
self.test_c_quizzes.append(new_c_quizzes)
def comput_correctness(self, c_quizzes, models_for_validation):
self.test_c_quizzes.append(new_c_quizzes)
def comput_correctness(self, c_quizzes, models_for_validation):
- ###############################################################
# Create the reverse quizzes
token_forward, token_backward = self.problem.direction_tokens()
# Create the reverse quizzes
token_forward, token_backward = self.problem.direction_tokens()
@@
-328,11
+327,9
@@
class QuizzMachine:
ar_mask = self.make_ar_mask(c_quizzes)
seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
ar_mask = self.make_ar_mask(c_quizzes)
seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
- ###############################################################
- # Check how many of the other models can solve them in both
- # directions
+ # Check how many of models can solve the quizzes in both directions
- nb_correct =
[]
+ nb_correct =
0
for model in models_for_validation:
result = c_quizzes.clone()
for model in models_for_validation:
result = c_quizzes.clone()
@@
-369,14
+366,13
@@
class QuizzMachine:
(reverse_c_quizzes == reverse_result).long().min(dim=-1).values
)
(reverse_c_quizzes == reverse_result).long().min(dim=-1).values
)
- nb_correct
.append((correct * reverse_correct)[None, :])
+ nb_correct
+= correct * reverse_correct
- return
torch.cat(nb_correct, dim=0).sum(dim=0)
+ return
nb_correct
- def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
- ###############################################################
- # Generate quizzes with model
+ ###############################################################
+ def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)