)
ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
- ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1
+ ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
ar_mask_solve = 1 - ar_mask_prompt
- seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
+ seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
# bracketing of the temperature to get the target logproba
- temperature = 1
+ warnings.warn("high temperature!", RuntimeWarning)
+ temperature = 2
d_temperature = 1 / 3
while True: