projects
/
culture.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
09c5eea
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 2 Jul 2024 10:07:46 +0000
(13:07 +0300)
committer
François Fleuret
<francois@fleuret.org>
Tue, 2 Jul 2024 10:07:46 +0000
(13:07 +0300)
quizz_machine.py
patch
|
blob
|
history
diff --git
a/quizz_machine.py
b/quizz_machine.py
index
1a20563
..
5807b66
100755
(executable)
--- a/
quizz_machine.py
+++ b/
quizz_machine.py
@@
-395,8
+395,6
@@
class QuizzMachine:
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
- c_quizzes[:, 0] = self.token_forward
-
ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
ar_mask_second = 1 - ar_mask_first
ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
ar_mask_second = 1 - ar_mask_first
@@
-406,16
+404,13
@@
class QuizzMachine:
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
if reverse_cleanup:
seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)
if reverse_cleanup:
- warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
temperature = 10.0
else:
temperature = 1.0
temperature = 10.0
else:
temperature = 1.0
- # warnings.warn("noise injection", RuntimeWarning)
- # noise_std = torch.rand(1).item()
- # self.logger(f"{noise_std=}")
+ # First, we generate the answer at high temperature
- # mygpt.set_noise_injection(model_for_generation, noise_std)
+ c_quizzes[:, 0] = self.token_backward
masked_inplace_autoregression(
model=model_for_generation,
masked_inplace_autoregression(
model=model_for_generation,
@@
-428,10
+423,10
@@
class QuizzMachine:
device=self.device,
)
device=self.device,
)
- # mygpt.set_noise_injection(model_for_generation, 0.0)
-
ave_seq_logproba = seq_logproba.mean()
ave_seq_logproba = seq_logproba.mean()
+ # Then, we generate the prompt deterministically
+
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
masked_inplace_autoregression(
model=model_for_generation,
batch_size=self.batch_size,
@@
-443,31
+438,20
@@
class QuizzMachine:
device=self.device,
)
device=self.device,
)
- if reverse_cleanup:
- c_quizzes = self.reverse_time(c_quizzes)
+ # Then we reverse the quizz, and re-generate the response, now
+ # deterministically
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_second,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
- device=self.device,
- )
-
- c_quizzes = self.reverse_time(c_quizzes)
+ c_quizzes = self.reverse_time(c_quizzes)
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_second,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_second,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ device=self.device,
+ )
return c_quizzes, seq_logproba.mean()
return c_quizzes, seq_logproba.mean()