("gray", [192, 192, 192]),
]
- def __init__(
- self,
- ):
+ def __init__(self, device=torch.device("cpu")):
self.colors = torch.tensor([c for _, c in self.named_colors])
self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
self.height = 10
self.width = 10
+ self.device = device
######################################################################
return len(self.colors)
def rec_coo(self, x, n, min_height=3, min_width=3):
+ K = 3
+ N = 4000
+
+ while True:
+ v = (
+ (
+ torch.rand(N * K, self.height + 1, device=self.device)
+ .sort(dim=-1)
+ .indices
+ < 2
+ )
+ .long()
+ .cumsum(dim=1)
+ == 1
+ ).long()
+
+ h = (
+ (
+ torch.rand(N * K, self.width + 1, device=self.device)
+ .sort(dim=-1)
+ .indices
+ < 2
+ )
+ .long()
+ .cumsum(dim=1)
+ == 1
+ ).long()
+
+ i = torch.logical_and(
+ v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
+ )
+
+ v, h = v[i], h[i]
+ v = v[: v.size(0) - v.size(0) % K]
+ h = h[: h.size(0) - h.size(0) % K]
+ v = v.reshape(v.size(0) // K, K, -1)
+ h = h.reshape(h.size(0) // K, K, -1)
+
+ r = v[:, :, :, None] * h[:, :, None, :]
+
+ valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+
+ v = v[valid]
+ h = h[valid]
+
+ if v.size(0) > 0:
+ break
+
+ av = torch.arange(v.size(2), device=self.device)[None, :]
+ ah = torch.arange(h.size(2), device=self.device)[None, :]
+
+ return [
+ (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
+ for i1, j1, i2, j2 in zip(
+ v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
+ h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
+ (v[0] * av).max(dim=-1).values,
+ (h[0] * ah).max(dim=-1).values,
+ )
+ ]
+
+ def rec_coo_(self, x, n, min_height=3, min_width=3):
collision = x.new(x.size())
while True:
collision[...] = 0
######################################################################
- def generate_prompts_and_answers(self, nb):
+ def generate_prompts_and_answers(self, nb, device="cpu"):
tasks = [
self.task_replace_color,
self.task_move,
self.task_frame,
self.task_detect,
]
- prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
- answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)
+ prompts = torch.zeros(
+ nb, self.height, self.width * 3, dtype=torch.int64, device=self.device
+ )
+ answers = torch.zeros(
+ nb, self.height, self.width, dtype=torch.int64, device=self.device
+ )
w = self.width
for prompt, answer in tqdm.tqdm(