projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
quizz_machine.py
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
5807b66
..
0d6d8f5
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-333,7
+333,7
@@
class QuizzMachine:
)
def compute_correctness(
)
def compute_correctness(
- self, c_quizzes, models_for_validation, both_directions=
Tru
e
+ self, c_quizzes, models_for_validation, both_directions=
Fals
e
):
reversed_c_quizzes = self.reverse_time(c_quizzes)
):
reversed_c_quizzes = self.reverse_time(c_quizzes)
@@
-390,7
+390,7
@@
class QuizzMachine:
###############################################################
###############################################################
- def generate_quizzes(self, nb, model_for_generation
, reverse_cleanup=False
):
+ def generate_quizzes(self, nb, model_for_generation):
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
c_quizzes = torch.empty(
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
@@
-403,10
+403,7
@@
class QuizzMachine:
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
- if reverse_cleanup:
- temperature = 10.0
- else:
- temperature = 1.0
+ temperature = 10.0
# First, we generate the answer at high temperature
# First, we generate the answer at high temperature
@@
-433,7
+430,7
@@
class QuizzMachine:
input=c_quizzes,
ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
input=c_quizzes,
ar_mask=ar_mask_second,
seq_logproba=seq_logproba,
- temperature=
temperature
,
+ temperature=
1.0
,
deterministic_synthesis=True,
device=self.device,
)
deterministic_synthesis=True,
device=self.device,
)