- def start(self, nb):
- return (
- torch.arange(self.height * self.width)
- .reshape(1, 1, self.height, self.width)
- .expand(nb, -1, -1, -1)
+ def start_random(self, nb):
+ y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
+
+ m = (torch.rand(y.size()).sort(dim=-1).indices < y.size(1) // 2).long()
+
+ y = (y * m + self.height * self.width * (1 - m)).reshape(
+ nb, self.height, self.width