- dist[1:-1, 1:-1] += (X != 0).long() * self.height * self.width
- dist[0, :] = self.height * self.width
- dist[-1, :] = self.height * self.width
- dist[:, 0] = self.height * self.width
- dist[:, -1] = self.height * self.width
- # dist += torch.rand(dist.size())
-
- i, j = i0 + 1, j0 + 1
- while i != i1 + 1 or j != j1 + 1:
- f_X[i - 1, j - 1] = c[2]
- r, s, t, u = (
- dist[i - 1, j],
- dist[i, j - 1],
- dist[i + 1, j],
- dist[i, j + 1],
- )
- m = min(r, s, t, u)
- if r == m:
- i = i - 1
- elif t == m:
- i = i + 1
- elif s == m:
- j = j - 1
- else:
- j = j + 1
+ dist0[...] = 1
+ dist0[1:-1, 1:-1] = (X != 0).long()
+ dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)