# Vectorized rejection sampling of non-overlapping rectangles: it is
# quite a tensorial spaghetti mess, but generating 100k samples went
# from 1h50 with plain pure-Python code to 3min30s with this one.
def rec_coo(self, nb_rec, min_height=3, min_width=3):
    """Sample nb_rec pairwise non-overlapping axis-aligned rectangles.

    Repeatedly samples batches of random row/column intervals and keeps
    the first group of nb_rec rectangles whose occupancy maps do not
    overlap anywhere on the grid.

    Args:
        nb_rec: number of rectangles to return.
        min_height: minimum number of rows per rectangle.
        min_width: minimum number of columns per rectangle.

    Returns:
        A list of nb_rec tuples (i1, j1, i2, j2) with exclusive end
        coordinates, each rectangle at least min_height x min_width,
        fitting inside the self.height x self.width grid.
    """
    nb_trials = 200  # candidate rows sampled per rejection round

    def random_intervals(size):
        # One row per candidate: the sort of uniform noise yields a
        # random permutation of 0..size, so `< 2` marks two distinct
        # uniform positions p1 < p2; cumsum == 1 turns that into the
        # 0/1 indicator of the half-open interval [p1, p2).
        ranks = torch.rand(
            nb_trials * nb_rec, size + 1, device=self.device
        ).sort(dim=-1).indices
        marks = (ranks < 2).long()
        return (marks.cumsum(dim=1) == 1).long()

    def first_last(masks):
        # First and last index holding a 1 in each row of `masks`
        # (every row is a non-empty contiguous run of ones).
        n = masks.size(-1)
        pos = torch.arange(n, device=self.device)[None, :]
        first = n - (masks * (n - pos)).max(dim=-1).values
        last = (masks * pos).max(dim=-1).values
        return first, last

    while True:
        v = random_intervals(self.height)  # vertical extents
        h = random_intervals(self.width)   # horizontal extents

        # Reject candidates below the minimum size (interval length is
        # the number of ones in the mask).
        big_enough = torch.logical_and(
            v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
        )
        v, h = v[big_enough], h[big_enough]

        # Drop the remainder and group the survivors into candidate
        # sets of nb_rec rectangles each. The last dimension is given
        # explicitly: reshape(..., -1) would raise on a 0-row tensor.
        nb_groups = v.size(0) // nb_rec
        v = v[: nb_groups * nb_rec].reshape(nb_groups, nb_rec, v.size(-1))
        h = h[: nb_groups * nb_rec].reshape(nb_groups, nb_rec, h.size(-1))

        # r[g, k] is the 0/1 occupancy map of rectangle k in group g;
        # a group is valid iff no cell is covered more than once.
        r = v[:, :, :, None] * h[:, :, None, :]
        valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1

        v, h = v[valid], h[valid]

        if v.size(0) > 0:
            break

    # Convert the interval masks of the first valid group into corner
    # coordinates; the +1 makes the end coordinates exclusive.
    i1, i2 = first_last(v[0])
    j1, j2 = first_last(h[0])

    return [
        (a.item(), b.item(), c.item() + 1, d.item() + 1)
        for a, b, c, d in zip(i1, j1, i2, j2)
    ]
+
+ def rec_coo_(self, x, n, min_height=3, min_width=3):