nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+ # Two-pass generation: the prompt mask selects every position up to and
+ # including the middle token, the solve mask selects the remaining ones.
+ ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
+ ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
+ ar_mask_solve = 1 - ar_mask_prompt
- seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
+ # `ar_mask` no longer exists after this change; size the buffer from the
+ # prompt mask instead (one log-probability per sequence in the batch).
+ seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
# bracketing of the temperature to get the target logproba
model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
- ar_mask=ar_mask,
+ ar_mask=ar_mask_prompt,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
ave_seq_logproba = seq_logproba.mean()
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
# If we do not have target logprobs, get out now
if min_ave_seq_logproba is None:
break
weight = torch.full((1, 1, 3, 3), 1.0)
- # mask = (torch.rand(result[:, 0].size()) < 0.01).long()
- # rand = torch.randint(4, mask.size())
- # result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
+ mask = (torch.rand(result[:, 0].size()) < 0.01).long()
+ rand = torch.randint(4, mask.size())
+ result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
# empty->empty
# head->tail
if __name__ == "__main__":
import time
- wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+ wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
start_time = time.perf_counter()
frame_sequences = wireworld.generate_frame_sequences(nb=96)
# print(wireworld.seq2str(seq[:4]))
- # for t in range(frame_sequences.size(1)):
- # img = wireworld.seq2img(frame_sequences[:, t])
- # torchvision.utils.save_image(
- # img.float() / 255.0,
- # f"/tmp/frame_{t:03d}.png",
- # nrow=8,
- # padding=6,
- # pad_value=0,
- # )
+ for t in range(frame_sequences.size(1)):
+ img = wireworld.seq2img(frame_sequences[:, t])
+ torchvision.utils.save_image(
+ img.float() / 255.0,
+ f"/tmp/frame_{t:03d}.png",
+ nrow=8,
+ padding=6,
+ pad_value=0,
+ )
# m = (torch.rand(seq.size()) < 0.05).long()
# seq = (1 - m) * seq + m * 23
+ wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
token_sequences = wireworld.generate_token_sequences(32)
wireworld.save_quizzes(token_sequences, "/tmp", "seq")
# img = wireworld.seq2img(frame_sequences[:60])