seq_logproba = torch.zeros(nb, device=self.device)
- def heater(T):
-     return lambda s, logits: logits / T
+ # Fixed-temperature logit transformers: divide the logits by a constant,
+ # hot for noisy generation, cold for clean generation.
+ lt_noisy = lambda s, logits: logits / temperature_hot
+ lt_clean = lambda s, logits: logits / temperature_cold
+
+ # Alternative kept for reference: heat ~1% of the logit entries by a
+ # factor of 5 (divisor 1 + 4*1) instead of heating all of them uniformly.
+ # lt_noisy = lambda s, logits: logits / (
+ #     1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long()
+ # )
+ # lt_clean = None
if p2a_only:
    c_quizzes[...] = self.problem.token_forward
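
# Illustration, not part of the patch: a self-contained sketch of the
# commented-out spiky lt_noisy above. The parameter names `p` and `spike`
# are ours; the mask is drawn per logit entry, exactly as the
# torch.rand(logits.size()) draw in the commented code.
import torch

def spiky_transformer(s, logits, p=1e-2, spike=5.0):
    # Bernoulli(p) mask over individual logit entries, not whole positions.
    hot = (torch.rand(logits.size(), device=logits.device) < p).float()
    return logits / (1.0 + (spike - 1.0) * hot)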

    input=c_quizzes,
    ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
    seq_logproba=seq_logproba,
-   logit_transformer=heater(temperature_hot),
+   logit_transformer=lt_noisy,
    deterministic_synthesis=False,
    device=self.device,
)
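
# Sketch of the assumed interface, not the repo's code: the generation
# call above presumably invokes logit_transformer(s, logits) at each step
# s before sampling, so lt_noisy flattens the next-token distribution
# while the statement part of the quiz is produced.
import torch

def sample_step(step_logits, s, logit_transformer):
    # step_logits: (batch, vocab) raw logits for the position being filled.
    scaled = logit_transformer(s, step_logits)   # e.g. logits / temperature_hot
    probs = torch.softmax(scaled, dim=-1)        # flatter distribution when T > 1
    return torch.multinomial(probs, num_samples=1).squeeze(-1)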

    input=c_quizzes,
    ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
    seq_logproba=seq_logproba,
-   logit_transformer=heater(temperature_cold),
+   logit_transformer=lt_clean,
    deterministic_synthesis=False,
    device=self.device,
)
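
# Sketch of the assumed ar_mask contract (the shape strings such as
# "fwd_3_bck_123" are specific to problem.make_ar_mask and not reproduced
# here): entries equal to 1 are resampled left to right, entries equal to
# 0 are kept, so the call above regenerates only the answer part of each
# quiz at the cold temperature. The helper below is hypothetical.
import torch

def masked_autoregress(next_logits, seq, ar_mask, logit_transformer):
    # next_logits(seq, t) -> (batch, vocab) logits for position t (assumed).
    for t in range(seq.size(1)):
        keep = ar_mask[:, t] == 0
        scaled = logit_transformer(t, next_logits(seq, t))
        sampled = torch.multinomial(torch.softmax(scaled, dim=-1), 1).squeeze(-1)
        seq[:, t] = torch.where(keep, seq[:, t], sampled)
    return seq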

    input=c_quizzes,
    ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
    seq_logproba=seq_logproba,
-   logit_transformer=heater(temperature_hot),
+   logit_transformer=lt_noisy,
    deterministic_synthesis=False,
    device=self.device,
)

    input=c_quizzes,
    ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
    seq_logproba=seq_logproba,
-   logit_transformer=heater(temperature_cold),
+   logit_transformer=lt_clean,
    deterministic_synthesis=False,
    device=self.device,
)

    input=c_quizzes,
    ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
    seq_logproba=seq_logproba,
-   logit_transformer=heater(temperature_cold),
+   logit_transformer=lt_clean,
    deterministic_synthesis=False,
    device=self.device,
)
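
# Reading the five argument blocks in order (the enclosing calls and any
# control flow between them are elided from these hunks), the schedule
# appears to be two hot-statement / cold-answer rounds plus a final cold
# pass over the answer. A hypothetical driver making that structure
# explicit; `generate` stands in for the elided call around each block.
def hot_cold_schedule(generate, c_quizzes, lt_noisy, lt_clean):
    for _ in range(2):
        generate(c_quizzes, shape="fwd_012_bck_0", logit_transformer=lt_noisy)  # hot statement pass
        generate(c_quizzes, shape="fwd_3_bck_123", logit_transformer=lt_clean)  # cold answer pass
    generate(c_quizzes, shape="fwd_3_bck_123", logit_transformer=lt_clean)      # final cold pass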