Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 07:02:05 +0000 (09:02 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 22 Jul 2024 07:02:05 +0000 (09:02 +0200)
quiz_machine.py

index c73b6d0..a5f9a89 100755 (executable)
@@ -529,8 +529,13 @@ class QuizMachine:
 
         seq_logproba = torch.zeros(nb, device=self.device)
 
-        def heater(T):
-            return lambda s, logits: logits / T
+        lt_noisy = lambda s, logits: logits / temperature_hot
+        lt_clean = lambda s, logits: logits / temperature_cold
+
+        # lt_noisy = lambda s, logits: logits / (
+        # 1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long()
+        # )
+        # lt_clean = None
 
         if p2a_only:
             c_quizzes[...] = self.problem.token_forward
@@ -541,7 +546,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_hot),
+                logit_transformer=lt_noisy,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -552,7 +557,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -566,7 +571,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_hot),
+                logit_transformer=lt_noisy,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -577,7 +582,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )
@@ -590,7 +595,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"),
                 seq_logproba=seq_logproba,
-                logit_transformer=heater(temperature_cold),
+                logit_transformer=lt_clean,
                 deterministic_synthesis=False,
                 device=self.device,
             )