From: François Fleuret
Date: Mon, 22 Jul 2024 07:02:05 +0000 (+0200)
Subject: Update.
X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=03fcaec34516db7ac941059d6c48737d378e9ff5;p=culture.git

Update.
---

diff --git a/quiz_machine.py b/quiz_machine.py
index c73b6d0..a5f9a89 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -529,8 +529,13 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
-        def heater(T):
-            return lambda s, logits: logits / T
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        # lt_noisy = lambda s, logits: logits / (
+        # 1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long()
+        # )
+        # lt_clean = None
 
         if p2a_only:
             c_quizzes[...] = self.problem.token_forward
@@ -541,7 +546,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_hot),
+                logit_transformer=lt_noisy,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -552,7 +557,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -566,7 +571,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_hot),
+                logit_transformer=lt_noisy,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -577,7 +582,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -590,7 +595,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )