- k = torch.arange(result.size(1), device=result.device)[None,:]
- last_output_idx = ((result == self.t_output) * k).max(dim=1,keep_dim=True)
- first_prog_idx = ((result == self.t_prog) * k).min(dim=1,keep_dim=True)
+ k = torch.arange(result.size(1), device=result.device)[None, :]
+ last_output_idx = ((result == self.t_output) * k).max(dim=1, keep_dim=True)
+ first_prog_idx = ((result == self.t_prog) * k).min(dim=1, keep_dim=True)