return y
def start_error(self, x):
- i = torch.arange(self.height).reshape(1,-1,1).expand_as(x)
- j = torch.arange(self.width).reshape(1,1,-1).expand_as(x)
+ i = torch.arange(self.height, device=x.device).reshape(1,-1,1).expand_as(x)
+ j = torch.arange(self.width, device=x.device).reshape(1,1,-1).expand_as(x)
ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1)
rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)