- def start(self, nb):
- return (
- torch.arange(self.height * self.width)
- .reshape(1, 1, self.height, self.width)
- .expand(nb, -1, -1, -1)
+ def start_random(self, nb):
+ y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
+
+ # m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+
+ i = (
+ torch.arange(self.height)
+ .reshape(1, -1, 1)
+ .expand(nb, self.height, self.width)
+ )
+ j = (
+ torch.arange(self.width)
+ .reshape(1, 1, -1)
+ .expand(nb, self.height, self.width)
+ )
+
+ ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
+ rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
+
+ m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+
+ y = (y * m + self.height * self.width * (1 - m)).reshape(
+ nb, self.height, self.width
+ )
+
+ return y
+
+ def start_error(self, x):
+ i = torch.arange(self.height, device=x.device).reshape(1, -1, 1).expand_as(x)
+ j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
+
+ ri = (
+ (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1, 1, 1)
+ )
+ rj = (
+ (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1, 1, 1)