ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
- # bracketing of the temperature to get the target logproba
+ # bracketing of the temperature to get the target logproba if
+ # min_ave_seq_logproba is not None
- warnings.warn("high temperature!", RuntimeWarning)
temperature = 2
d_temperature = 1 / 3