3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import torch, torchvision
9 import torch.nn.functional as F
# Maps a color name to its [R, G, B] value (0-255 per channel). The
# insertion order matters: it defines the integer id of each color
# below, and "white" must stay first since id 0 is the background.
#
# NOTE(review): the listing this table was recovered from skips line
# numbers right after "white", "yellow", "forest_green" and
# "midnight_blue"; entries may be missing at those places -- confirm
# against the upstream file before relying on specific color ids.
color_name2rgb = {
    "white": [255, 255, 255],
    "yellow": [255, 255, 0],
    "maroon": [128, 0, 0],
    "dark_red": [139, 0, 0],
    "brown": [165, 42, 42],
    "firebrick": [178, 34, 34],
    "crimson": [220, 20, 60],
    "tomato": [255, 99, 71],
    "coral": [255, 127, 80],
    "indian_red": [205, 92, 92],
    "light_coral": [240, 128, 128],
    "dark_salmon": [233, 150, 122],
    "salmon": [250, 128, 114],
    "light_salmon": [255, 160, 122],
    "orange_red": [255, 69, 0],
    "dark_orange": [255, 140, 0],
    "orange": [255, 165, 0],
    "gold": [255, 215, 0],
    "dark_golden_rod": [184, 134, 11],
    "golden_rod": [218, 165, 32],
    "pale_golden_rod": [238, 232, 170],
    "dark_khaki": [189, 183, 107],
    "khaki": [240, 230, 140],
    "olive": [128, 128, 0],
    "yellow_green": [154, 205, 50],
    "dark_olive_green": [85, 107, 47],
    "olive_drab": [107, 142, 35],
    "lawn_green": [124, 252, 0],
    "chartreuse": [127, 255, 0],
    "green_yellow": [173, 255, 47],
    "dark_green": [0, 100, 0],
    "forest_green": [34, 139, 34],
    "lime_green": [50, 205, 50],
    "light_green": [144, 238, 144],
    "pale_green": [152, 251, 152],
    "dark_sea_green": [143, 188, 143],
    "medium_spring_green": [0, 250, 154],
    "spring_green": [0, 255, 127],
    "sea_green": [46, 139, 87],
    "medium_aqua_marine": [102, 205, 170],
    "medium_sea_green": [60, 179, 113],
    "light_sea_green": [32, 178, 170],
    "dark_slate_gray": [47, 79, 79],
    "teal": [0, 128, 128],
    "dark_cyan": [0, 139, 139],
    "aqua": [0, 255, 255],
    "cyan": [0, 255, 255],
    "light_cyan": [224, 255, 255],
    "dark_turquoise": [0, 206, 209],
    "turquoise": [64, 224, 208],
    "medium_turquoise": [72, 209, 204],
    "pale_turquoise": [175, 238, 238],
    "aqua_marine": [127, 255, 212],
    "powder_blue": [176, 224, 230],
    "cadet_blue": [95, 158, 160],
    "steel_blue": [70, 130, 180],
    "corn_flower_blue": [100, 149, 237],
    "deep_sky_blue": [0, 191, 255],
    "dodger_blue": [30, 144, 255],
    "light_blue": [173, 216, 230],
    "sky_blue": [135, 206, 235],
    "light_sky_blue": [135, 206, 250],
    "midnight_blue": [25, 25, 112],
    "dark_blue": [0, 0, 139],
    "medium_blue": [0, 0, 205],
    "royal_blue": [65, 105, 225],
    "blue_violet": [138, 43, 226],
    "indigo": [75, 0, 130],
    "dark_slate_blue": [72, 61, 139],
    "slate_blue": [106, 90, 205],
    "medium_slate_blue": [123, 104, 238],
    "medium_purple": [147, 112, 219],
    "dark_magenta": [139, 0, 139],
    "dark_violet": [148, 0, 211],
    "dark_orchid": [153, 50, 204],
    "medium_orchid": [186, 85, 211],
    "purple": [128, 0, 128],
    "thistle": [216, 191, 216],
    "plum": [221, 160, 221],
    "violet": [238, 130, 238],
    "magenta": [255, 0, 255],
    "orchid": [218, 112, 214],
    "medium_violet_red": [199, 21, 133],
    "pale_violet_red": [219, 112, 147],
    "deep_pink": [255, 20, 147],
    "hot_pink": [255, 105, 180],
    "light_pink": [255, 182, 193],
    "pink": [255, 192, 203],
    "antique_white": [250, 235, 215],
    "beige": [245, 245, 220],
    "bisque": [255, 228, 196],
    "blanched_almond": [255, 235, 205],
    "wheat": [245, 222, 179],
    "corn_silk": [255, 248, 220],
    "lemon_chiffon": [255, 250, 205],
    "light_golden_rod_yellow": [250, 250, 210],
    "light_yellow": [255, 255, 224],
    "saddle_brown": [139, 69, 19],
    "sienna": [160, 82, 45],
    "chocolate": [210, 105, 30],
    "peru": [205, 133, 63],
    "sandy_brown": [244, 164, 96],
    "burly_wood": [222, 184, 135],
    "tan": [210, 180, 140],
    "rosy_brown": [188, 143, 143],
    "moccasin": [255, 228, 181],
    "navajo_white": [255, 222, 173],
    "peach_puff": [255, 218, 185],
    "misty_rose": [255, 228, 225],
    "lavender_blush": [255, 240, 245],
    "linen": [250, 240, 230],
    "old_lace": [253, 245, 230],
    "papaya_whip": [255, 239, 213],
    "sea_shell": [255, 245, 238],
    "mint_cream": [245, 255, 250],
    "slate_gray": [112, 128, 144],
    "light_slate_gray": [119, 136, 153],
    "light_steel_blue": [176, 196, 222],
    "lavender": [230, 230, 250],
    "floral_white": [255, 250, 240],
    "alice_blue": [240, 248, 255],
    "ghost_white": [248, 248, 255],
    "honeydew": [240, 255, 240],
    "ivory": [255, 255, 240],
    "azure": [240, 255, 255],
    "snow": [255, 250, 250],
    "silver": [192, 192, 192],
    "gainsboro": [220, 220, 220],
    "white_smoke": [245, 245, 245],
}

# Bidirectional name <-> integer id maps; ids follow the insertion
# order of color_name2rgb, so "white" gets id 0.
color_name2id = {name: k for k, name in enumerate(color_name2rgb)}
color_id2name = {k: name for k, name in enumerate(color_name2rgb)}
152 ######################################################################
def all_properties(height, width, nb_squares, square_i, square_j, square_c):
    """Return the list of all statements that are true for a scene.

    Args:
        height, width: size of the grid, in cells.
        nb_squares: number of squares in the scene.
        square_i, square_j: row and column of each square, indexable by
            square number.
        square_c: color id of each square; elements must provide
            ``.item()`` (e.g. a torch tensor) and be keys of
            ``color_id2name``.

    Returns:
        A list of strings. For every square: "there is <color>", a
        location statement when the square lies in the outer third of the
        grid on some side, and pairwise relations ("below", "above",
        "right of", "left of") with every other square.
    """
    # Resolve each square's color name once instead of once per pair.
    names = [color_id2name[square_c[k].item()] for k in range(nb_squares)]

    s = []

    for r, c_r in enumerate(names):
        s.append(f"there is {c_r}")

        # Rows grow downward: a large row index means "bottom".
        # NOTE(review): the "top" and "left" statements were restored by
        # symmetry with "bottom" and "right" -- confirm the exact wording.
        if square_i[r] >= height - height // 3:
            s.append(f"{c_r} bottom")
        if square_i[r] < height // 3:
            s.append(f"{c_r} top")
        if square_j[r] >= width - width // 3:
            s.append(f"{c_r} right")
        if square_j[r] < width // 3:
            s.append(f"{c_r} left")

        # Comparing a square with itself yields nothing, since every test
        # below is a strict inequality.
        for t, c_t in enumerate(names):
            if square_i[r] > square_i[t]:
                s.append(f"{c_r} below {c_t}")
            if square_i[r] < square_i[t]:
                s.append(f"{c_r} above {c_t}")
            if square_j[r] > square_j[t]:
                s.append(f"{c_r} right of {c_t}")
            if square_j[r] < square_j[t]:
                s.append(f"{c_r} left of {c_t}")

    return s
185 ######################################################################
187 # Generates sequences
195 max_nb_properties=10,
200 assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
206 nb_squares = torch.randint(max_nb_squares, (1,)) + 1
207 square_position = torch.randperm(height * width)[:nb_squares]
209 # color 0 is white and reserved for the background
210 square_c = torch.randperm(nb_colors)[:nb_squares] + 1
211 square_i = square_position.div(width, rounding_mode="floor")
212 square_j = square_position % width
214 img = torch.zeros(height * width, dtype=torch.int64)
215 for k in range(nb_squares):
216 img[square_position[k]] = square_c[k]
218 # generates all the true properties
220 s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
222 if pruner is not None:
223 s = list(filter(pruner, s))
225 # picks at most max_nb_properties at random
227 nb_properties = torch.randint(max_nb_properties, (1,)) + 1
229 " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
231 + " ".join([f"{color_id2name[n.item()]}" for n in img])
239 ######################################################################
# Extracts the image after <img> in descr as a 1x3xHxW tensor


def descr2img(descr, n, height, width):
    """Decode the n-th image of a description into an RGB tensor.

    Args:
        descr: a description string, or a list of them; for a list the
            per-description results are concatenated along dimension 0.
        n: index of the image to decode (0 is the text following the
            first "<img>"), or a list of such indices.
        height, width: size of the image, in cells.

    Returns:
        For a single string and a single index, a 1x3xHxW tensor holding
        the RGB value of each cell as given by ``color_name2rgb``.
    """
    if isinstance(descr, list):
        return torch.cat([descr2img(d, n, height, width) for d in descr], 0)

    # NOTE(review): the guard and the unsqueeze dimension of this branch
    # were restored from context (only the ``for k in n`` line survived)
    # -- confirm against the upstream file.
    if isinstance(n, list):
        return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
            0
        )

    def token2color(t):
        # Anything that is not a known color name, including the "<unk>"
        # padding added below, is drawn mid-gray.
        return color_name2rgb.get(t, [128, 128, 128])

    d = descr.split("<img>")
    # No n-th image in the description: decode an empty token sequence.
    d = d[n + 1] if len(d) > n + 1 else ""
    d = d.strip().split(" ")[: height * width]
    # Pad short sequences so the reshape below always succeeds.
    d = d + ["<unk>"] * (height * width - len(d))
    d = [token2color(t) for t in d]
    # (H*W, 3) -> (3, H*W) -> (1, 3, H, W)
    img = torch.tensor(d).permute(1, 0)
    img = img.reshape(1, 3, height, width)

    return img
271 ######################################################################
# Returns all the properties of the image after <img> in descr


def descr2properties(descr, height, width):
    """List the properties verified by the last image of a description.

    Args:
        descr: a description string, or a list of them (in which case a
            list of results, one per description, is returned).
        height, width: size of the image, in cells.

    Returns:
        The list of property strings computed by ``all_properties`` for
        the squares found in the text after the last "<img>". An empty
        list is returned when the image does not have exactly
        ``height * width`` tokens.
    """
    if isinstance(descr, list):
        return [descr2properties(d, height, width) for d in descr]

    d = descr.split("<img>")
    img_tokens = d[-1] if len(d) > 1 else ""
    img_tokens = img_tokens.strip().split(" ")[: height * width]
    if len(img_tokens) != height * width:
        return []

    background = color_id2name[0]

    # color name -> (color id, row, column) of the square of that color
    seen = {}
    for k, x in enumerate(img_tokens):
        if x == background:
            continue
        # NOTE(review): rejecting unknown tokens and repeated colors was
        # restored from context (the original lines are missing from the
        # listing) -- confirm against the upstream file.
        if x not in color_name2rgb or x in seen:
            return []
        seen[x] = (color_name2id[x], k // width, k % width)

    # Transpose [(c, i, j), ...] into ((c, ...), (i, ...), (j, ...)); the
    # result is an empty tuple when the image has no square.
    square_infos = tuple(zip(*seen.values()))

    if square_infos:
        square_c = torch.tensor(square_infos[0])
        square_i = torch.tensor(square_infos[1])
        square_j = torch.tensor(square_infos[2])
    else:
        square_c = torch.tensor([])
        square_i = torch.tensor([])
        square_j = torch.tensor([])

    return all_properties(height, width, len(seen), square_i, square_j, square_c)
313 ######################################################################
# Returns a triplet composed of (1) the total number of properties
# before <img> in descr, (2) the total number of properties the image
# after <img> verifies, and (3) the number of properties in (1) not in
# (2)


def nb_properties(descr, height, width, pruner=None):
    """Count requested, verified and missing properties of a description.

    Args:
        descr: a description string, or a list of them (in which case a
            list of triplets, one per description, is returned).
        height, width: size of the image, in cells.
        pruner: optional predicate on property strings; when given, only
            the requested properties it accepts are counted.

    Returns:
        A tuple ``(nb_requested, nb_verified, nb_missing)`` where
        ``nb_missing`` counts the requested properties that the image
        does not verify.
    """
    if isinstance(descr, list):
        return [nb_properties(d, height, width, pruner) for d in descr]

    # Everything before the first "<img>" is the "<sep>"-separated list
    # of requested properties.
    d = descr.split("<img>", 1)
    d = d[0].strip().split("<sep>")
    d = [x.strip() for x in d]

    # Named so as not to shadow the module-level all_properties function.
    image_properties = set(descr2properties(descr, height, width))

    if pruner is None:
        requested_properties = set(d)
    else:
        requested_properties = set(filter(pruner, d))

    missing_properties = requested_properties - image_properties

    return (
        len(requested_properties),
        len(image_properties),
        len(missing_properties),
    )
344 ######################################################################
346 if __name__ == "__main__":
348 descr = generate(nb=1, height=12, width=16)
350 print(nb_properties(descr, height=12, width=16))
352 with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
356 img = descr2img(descr, n=0, height=12, width=16)
358 img = F.pad(img, (1, 1, 1, 1), value=64)
360 torchvision.utils.save_image(
362 f"picoclvr_example_{n:02d}.png",
370 start_time = time.perf_counter()
371 descr = generate(nb=1000, height=12, width=16)
372 end_time = time.perf_counter()
373 print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
375 ######################################################################