nb_colors=5,
pruner=None,
):
-
assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
descr = []
for n in range(nb):
-
nb_squares = torch.randint(max_nb_squares, (1,)) + 1
square_position = torch.randperm(height * width)[:nb_squares]
# Extracts the image after <img> in descr as a 1x3xHxW tensor
-def descr2img(descr, n, height, width):
-
- if type(descr) == list:
- return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
-
- if type(n) == list:
- return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
- 0
- )
+def descr2img(descr, height, width):
+ result = []
def token2color(t):
try:
except KeyError:
return [128, 128, 128]
- d = descr.split("<img>")
- d = d[n + 1] if len(d) > n + 1 else ""
- d = d.strip().split(" ")[: height * width]
- d = d + ["<unk>"] * (height * width - len(d))
- d = [token2color(t) for t in d]
- img = torch.tensor(d).permute(1, 0)
- img = img.reshape(1, 3, height, width)
+ for d in descr:
+ d = d.split("<img>")[1]
+ d = d.strip().split(" ")[: height * width]
+ d = d + ["<unk>"] * (height * width - len(d))
+ d = [token2color(t) for t in d]
+ img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+ result.append(img)
- return img
+ return torch.cat(result, 0)
######################################################################
def descr2properties(descr, height, width):
-
if type(descr) == list:
return [descr2properties(d, height, width) for d in descr]
def nb_properties(descr, height, width, pruner=None):
-
if type(descr) == list:
return [nb_properties(d, height, width, pruner) for d in descr]
for d in descr:
f.write(f"{d}\n\n")
- img = descr2img(descr, n=0, height=12, width=16)
+ img = descr2img(descr, height=12, width=16)
if img.size(0) == 1:
img = F.pad(img, (1, 1, 1, 1), value=64)