+ for q, e in zip(l_q, l_d):
+ d = c[e]
+ i, j = q % self.height, q // self.height
+ if (
+ nb[e] < self.width
+ and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0
+ ).all():
+ X[i, j] = d
+ nb[e] += 1
+
+ l_q = torch.randperm((self.height - 2) * (self.width - 2))[
+ : self.height * self.width // 2
+ ]
+ l_d = torch.randint(N, l_q.size())
+ for q, e in zip(l_q, l_d):
+ d = c[e]
+ i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1
+ a1, a2, a3 = X[i - 1, j - 1 : j + 2]
+ a8, a4 = X[i, j - 1], X[i, j + 1]
+ a7, a6, a5 = X[i + 1, j - 1 : j + 2]
+ if (
+ X[i, j] == 0
+ and nb[e] < self.width
+ and (a2 == 0 or a2 == d)
+ and (a4 == 0 or a4 == d)
+ and (a6 == 0 or a6 == d)
+ and (a8 == 0 or a8 == d)
+ and (a1 == 0 or a2 == d or a8 == d)
+ and (a3 == 0 or a4 == d or a2 == d)
+ and (a5 == 0 or a6 == d or a4 == d)
+ and (a7 == 0 or a8 == d or a6 == d)
+ ):
+ o = (
+ (a2 != 0).long()
+ + (a4 != 0).long()
+ + (a6 != 0).long()
+ + (a8 != 0).long()
+ )
+ if o <= 1:
+ X[i, j] = d
+ nb[e] += 1 - o
+
+ for e in range(N):
+ for j in range(nb[e]):
+ f_X[e, j] = c[e]
+
+ # @torch.compile