("gray", [192, 192, 192]),
]
- def __init__(
- self,
- ):
+ def __init__(self, device=torch.device("cpu")):
self.colors = torch.tensor([c for _, c in self.named_colors])
self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
self.height = 10
self.width = 10
+ self.device = device
######################################################################
def nb_token_values(self):
return len(self.colors)
+ # That's quite a tensorial spaghetti mess to sample
+ # non-overlapping rectangles quickly, but made the generation of
+ # 100k samples from 1h50 with a lame pure python code to 4min with
+ # this one.
def rec_coo(self, x, n, min_height=3, min_width=3):
+ K = 3
+ N = 1000
+
+ while True:
+ v = (
+ (
+ torch.rand(N * K, self.height + 1, device=self.device)
+ .sort(dim=-1)
+ .indices
+ < 2
+ )
+ .long()
+ .cumsum(dim=1)
+ == 1
+ ).long()
+
+ h = (
+ (
+ torch.rand(N * K, self.width + 1, device=self.device)
+ .sort(dim=-1)
+ .indices
+ < 2
+ )
+ .long()
+ .cumsum(dim=1)
+ == 1
+ ).long()
+
+ i = torch.logical_and(
+ v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
+ )
+
+ v, h = v[i], h[i]
+ v = v[: v.size(0) - v.size(0) % K]
+ h = h[: h.size(0) - h.size(0) % K]
+ v = v.reshape(v.size(0) // K, K, -1)
+ h = h.reshape(h.size(0) // K, K, -1)
+
+ r = v[:, :, :, None] * h[:, :, None, :]
+
+ valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+
+ v = v[valid]
+ h = h[valid]
+
+ if v.size(0) > 0:
+ break
+
+ av = torch.arange(v.size(2), device=self.device)[None, :]
+ ah = torch.arange(h.size(2), device=self.device)[None, :]
+
+ return [
+ (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
+ for i1, j1, i2, j2 in zip(
+ v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
+ h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
+ (v[0] * av).max(dim=-1).values,
+ (h[0] * ah).max(dim=-1).values,
+ )
+ ]
+
+ def rec_coo_(self, x, n, min_height=3, min_width=3):
collision = x.new(x.size())
while True:
collision[...] = 0
######################################################################
- def generate_prompts_and_answers(self, nb):
+ def generate_prompts_and_answers(self, nb, device="cpu"):
tasks = [
self.task_replace_color,
self.task_move,
f_B = answer
task = tasks[torch.randint(len(tasks), (1,))]
task(A, f_A, B, f_B)
+
return prompts.flatten(1), answers.flatten(1)
def save_quizzes(