From f9aee903b896ae73dafe0cce1dcc40d8a39accb0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 30 Jun 2024 08:10:22 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 18 ++++++++++++++++-- wireworld.py | 27 ++++++++++++++------------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 49e7835..c5870d0 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -378,7 +378,9 @@ class QuizzMachine: nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(c_quizzes.size(), 1, device=self.device) + ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) + ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1 + ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask.size(0), device=self.device) # bracketing of the temperature to get the target logproba @@ -393,7 +395,7 @@ class QuizzMachine: model=model_for_generation, batch_size=self.batch_size, input=c_quizzes, - ar_mask=ar_mask, + ar_mask=ar_mask_prompt, seq_logproba=seq_logproba, temperature=temperature, deterministic_synthesis=False, @@ -403,6 +405,18 @@ class QuizzMachine: ave_seq_logproba = seq_logproba.mean() + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + # progress_bar_desc="sampling c_quizzes", + device=self.device, + ) + # If we do not have target logprobs, get out now if min_ave_seq_logproba is None: break diff --git a/wireworld.py b/wireworld.py index 9e7d513..65b12ad 100755 --- a/wireworld.py +++ b/wireworld.py @@ -121,9 +121,9 @@ class Wireworld(problem.Problem): weight = torch.full((1, 1, 3, 3), 1.0) - # mask = (torch.rand(result[:, 0].size()) < 0.01).long() - # rand = torch.randint(4, mask.size()) - # result[:, 0] = mask * rand + (1 - mask) * result[:, 0] + mask = (torch.rand(result[:, 0].size()) < 0.01).long() + rand = torch.randint(4, mask.size()) + result[:, 0] = mask * rand + (1 - mask) * result[:, 0] # empty->empty # head->tail @@ -314,7 +314,7 @@ class Wireworld(problem.Problem): if __name__ == "__main__": import time - wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) + wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1) start_time = time.perf_counter() frame_sequences = wireworld.generate_frame_sequences(nb=96) @@ -323,19 +323,20 @@ if __name__ == "__main__": # print(wireworld.seq2str(seq[:4])) - # for t in range(frame_sequences.size(1)): - # img = wireworld.seq2img(frame_sequences[:, t]) - # torchvision.utils.save_image( - # img.float() / 255.0, - # f"/tmp/frame_{t:03d}.png", - # nrow=8, - # padding=6, - # pad_value=0, - # ) + for t in range(frame_sequences.size(1)): + img = wireworld.seq2img(frame_sequences[:, t]) + torchvision.utils.save_image( + img.float() / 255.0, + f"/tmp/frame_{t:03d}.png", + nrow=8, + padding=6, + pad_value=0, + ) # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 + wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) token_sequences = wireworld.generate_token_sequences(32) wireworld.save_quizzes(token_sequences, "/tmp", "seq") # img = wireworld.seq2img(frame_sequences[:60]) -- 2.39.5