+ X[...] = 0
+ f_X[...] = 0
+
+ c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+
+ for r in range(nb_rec):
+ while True:
+ i1, i2 = torch.randint(self.height - 2, (2,)) + 1
+ j1, j2 = torch.randint(self.width - 2, (2,)) + 1
+ if (
+ i2 >= i1
+ and j2 >= j1
+ and max(i2 - i1, j2 - j1) >= 2
+ and min(i2 - i1, j2 - j1) <= 3
+ ):
+ break
+ X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+ i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj
+
+ i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
+ i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2
+
+ i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
+ i1, i2 = i1.long() + di, i2.long() + di
+ j1, j2 = j1.long() + dj, j2.long() + dj
+ if i1 > i2:
+ i1, i2 = i2, i1
+ if j1 > j2:
+ j1, j2 = j2, j1
+
+ f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]
+
+ n = F.one_hot(X.flatten()).sum(dim=0)[1:]
+ if (
+ n.sum() > self.height * self.width // 4
+ and (n > 0).long().sum() == nb_rec
+ ):