ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
- warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
- temperature = 10
+ if reverse_cleanup:
+ warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
+ temperature = 10.0
+ else:
+ temperature = 1.0
# warnings.warn("noise injection", RuntimeWarning)
# noise_std = torch.rand(1).item()