- d = descr.split("<img>")
- d = d[n + 1] if len(d) > n + 1 else ""
- d = d.strip().split(" ")[: height * width]
- d = d + ["<unk>"] * (height * width - len(d))
- d = [token2color(t) for t in d]
- img = torch.tensor(d).permute(1, 0)
- img = img.reshape(1, 3, height, width)
+ for d in descr:
+ d = d.split("<img>")[1]
+ d = d.strip().split(" ")[: height * width]
+ d = d + ["<unk>"] * (height * width - len(d))
+ d = [token2color(t) for t in d]
+ img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
+ result.append(img)