Update.
author François Fleuret <francois@fleuret.org>
Tue, 2 Jul 2024 16:33:15 +0000 (19:33 +0300)
committer François Fleuret <francois@fleuret.org>
Tue, 2 Jul 2024 16:33:15 +0000 (19:33 +0300)
main.py
quizz_machine.py
sky.py

diff --git a/main.py b/main.py
index 7b8b642..918f75d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -93,6 +93,12 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+parser.add_argument("--generation_temperature", type=float, default=1.0)
+
+parser.add_argument("--stochastic_validation", action="store_true", default=False)
+
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -411,10 +417,14 @@ def create_c_quizzes(
             c_quizzes = quizz_machine.generate_quizzes(
                 nb_to_create,
                 model_for_generation=model_for_generation,
+                temperature=args.generation_temperature,
             )
 
             nb_correct, seq_logproba = quizz_machine.compute_correctness(
-                c_quizzes, models, both_directions=args.both_directions
+                c_quizzes,
+                models,
+                both_directions=args.both_directions,
+                deterministic_validation=not args.stochastic_validation,
             )
 
             for n, l in zip(nb_correct, seq_logproba):
diff --git a/quizz_machine.py b/quizz_machine.py
index 470b095..9b64941 100755 (executable)
--- a/quizz_machine.py
+++ b/quizz_machine.py
@@ -322,7 +322,11 @@ class QuizzMachine:
         )
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=False
+        self,
+        c_quizzes,
+        models_for_validation,
+        both_directions=False,
+        deterministic_validation=True,
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -349,7 +353,7 @@ class QuizzMachine:
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba[:, model.id],
                 temperature=1.0,
-                deterministic_synthesis=True,
+                deterministic_synthesis=deterministic_validation,
                 # progress_bar_desc="solving c_quizzes",
                 device=self.device,
             )
@@ -366,7 +370,7 @@ class QuizzMachine:
                     ar_mask=ar_mask,
                     seq_logproba=seq_logproba[:, model.id],
                     temperature=1.0,
-                    deterministic_synthesis=True,
+                    deterministic_synthesis=deterministic_validation,
                     # progress_bar_desc="solving reversed c_quizzes",
                     device=self.device,
                 )
@@ -385,7 +389,7 @@ class QuizzMachine:
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation):
+    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -398,8 +402,6 @@ class QuizzMachine:
 
         seq_logproba = torch.zeros(ar_mask_first.size(0), device=self.device)
 
-        temperature = 10.0
-
         # First, we generate the answer at high temperature
 
         c_quizzes[:, 0] = self.token_backward
@@ -415,7 +417,7 @@ class QuizzMachine:
             device=self.device,
         )
 
-        # Then, we generate the prompt deterministically
+        # Then, we generate the prompt at low temperature
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -423,13 +425,13 @@ class QuizzMachine:
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=1.0,
-            deterministic_synthesis=True,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
             device=self.device,
         )
 
         # Then we return the quizz, and re-generate the response, now
-        # deterministically
+        # at low temperature
 
         c_quizzes = self.reverse_time(c_quizzes)
 
@@ -439,8 +441,8 @@ class QuizzMachine:
             input=c_quizzes,
             ar_mask=ar_mask_second,
             seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=True,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
             device=self.device,
         )
 
diff --git a/sky.py b/sky.py
index 2183cf1..040ec67 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -42,9 +42,6 @@ class Sky(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     def __init__(
         self,
         height=6,
@@ -155,15 +152,6 @@ class Sky(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-        return prompts, answers
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
         x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
@@ -250,6 +238,18 @@ class Sky(problem.Problem):
             img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
         )
 
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+        return prompts, answers
+
     def save_quizzes(
         self,
         result_dir,