Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 30 Jun 2024 05:10:22 +0000 (08:10 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 30 Jun 2024 05:10:22 +0000 (08:10 +0300)
quizz_machine.py
wireworld.py

index 49e7835..c5870d0 100755 (executable)
@@ -378,7 +378,9 @@ class QuizzMachine:
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
 
-        ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
+        ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1
+        ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
 
         # bracketing of the temperature to get the target logproba
@@ -393,7 +395,7 @@ class QuizzMachine:
                 model=model_for_generation,
                 batch_size=self.batch_size,
                 input=c_quizzes,
-                ar_mask=ar_mask,
+                ar_mask=ar_mask_prompt,
                 seq_logproba=seq_logproba,
                 temperature=temperature,
                 deterministic_synthesis=False,
@@ -403,6 +405,18 @@ class QuizzMachine:
 
             ave_seq_logproba = seq_logproba.mean()
 
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                # progress_bar_desc="sampling c_quizzes",
+                device=self.device,
+            )
+
             # If we do not have target logprobs, get out now
             if min_ave_seq_logproba is None:
                 break
index 9e7d513..65b12ad 100755 (executable)
@@ -121,9 +121,9 @@ class Wireworld(problem.Problem):
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
-        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
-        rand = torch.randint(4, mask.size())
-        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
+        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
+        rand = torch.randint(4, mask.size())
+        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
 
         # empty->empty
         # head->tail
@@ -314,7 +314,7 @@ class Wireworld(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)
@@ -323,19 +323,20 @@ if __name__ == "__main__":
 
     # print(wireworld.seq2str(seq[:4]))
 
-    for t in range(frame_sequences.size(1)):
-    # img = wireworld.seq2img(frame_sequences[:, t])
-    # torchvision.utils.save_image(
-    # img.float() / 255.0,
-    # f"/tmp/frame_{t:03d}.png",
-    # nrow=8,
-    # padding=6,
-    # pad_value=0,
-    # )
+    for t in range(frame_sequences.size(1)):
+        img = wireworld.seq2img(frame_sequences[:, t])
+        torchvision.utils.save_image(
+            img.float() / 255.0,
+            f"/tmp/frame_{t:03d}.png",
+            nrow=8,
+            padding=6,
+            pad_value=0,
+        )
 
     # m = (torch.rand(seq.size()) < 0.05).long()
     # seq = (1 - m) * seq + m * 23
 
+    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
     token_sequences = wireworld.generate_token_sequences(32)
     wireworld.save_quizzes(token_sequences, "/tmp", "seq")
     # img = wireworld.seq2img(frame_sequences[:60])