- ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
- ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
- ar_mask_solve = 1 - ar_mask_prompt
- seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+ c_quizzes[:, 0] = self.token_forward
+
+ ar_mask_first = torch.zeros(c_quizzes.size(), device=self.device)
+ ar_mask_first[:, : ar_mask_first.size(1) // 2 + 1] = 1
+ ar_mask_second = 1 - ar_mask_first
+ ar_mask_first[:, 0] = 0
+ ar_mask_second[:, 0] = 0
+
+ seq_logproba = torch.empty(ar_mask_first.size(0), device=self.device)