Initial commit.
[picoclvr.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9 import torch.nn.functional as F
10
11 colors = [
12     [255, 255, 255],
13     [255, 0, 0],
14     [0, 128, 0],
15     [0, 0, 255],
16     [255, 255, 0],
17     [0, 0, 0],
18     [128, 0, 0],
19     [139, 0, 0],
20     [165, 42, 42],
21     [178, 34, 34],
22     [220, 20, 60],
23     [255, 99, 71],
24     [255, 127, 80],
25     [205, 92, 92],
26     [240, 128, 128],
27     [233, 150, 122],
28     [250, 128, 114],
29     [255, 160, 122],
30     [255, 69, 0],
31     [255, 140, 0],
32     [255, 165, 0],
33     [255, 215, 0],
34     [184, 134, 11],
35     [218, 165, 32],
36     [238, 232, 170],
37     [189, 183, 107],
38     [240, 230, 140],
39     [128, 128, 0],
40     [154, 205, 50],
41     [85, 107, 47],
42     [107, 142, 35],
43     [124, 252, 0],
44     [127, 255, 0],
45     [173, 255, 47],
46     [0, 100, 0],
47     [34, 139, 34],
48     [0, 255, 0],
49     [50, 205, 50],
50     [144, 238, 144],
51     [152, 251, 152],
52     [143, 188, 143],
53     [0, 250, 154],
54     [0, 255, 127],
55     [46, 139, 87],
56     [102, 205, 170],
57     [60, 179, 113],
58     [32, 178, 170],
59     [47, 79, 79],
60     [0, 128, 128],
61     [0, 139, 139],
62     [0, 255, 255],
63     [0, 255, 255],
64     [224, 255, 255],
65     [0, 206, 209],
66     [64, 224, 208],
67     [72, 209, 204],
68     [175, 238, 238],
69     [127, 255, 212],
70     [176, 224, 230],
71     [95, 158, 160],
72     [70, 130, 180],
73     [100, 149, 237],
74     [0, 191, 255],
75     [30, 144, 255],
76     [173, 216, 230],
77     [135, 206, 235],
78     [135, 206, 250],
79     [25, 25, 112],
80     [0, 0, 128],
81     [0, 0, 139],
82     [0, 0, 205],
83     [65, 105, 225],
84     [138, 43, 226],
85     [75, 0, 130],
86     [72, 61, 139],
87     [106, 90, 205],
88     [123, 104, 238],
89     [147, 112, 219],
90     [139, 0, 139],
91     [148, 0, 211],
92     [153, 50, 204],
93     [186, 85, 211],
94     [128, 0, 128],
95     [216, 191, 216],
96     [221, 160, 221],
97     [238, 130, 238],
98     [255, 0, 255],
99     [218, 112, 214],
100     [199, 21, 133],
101     [219, 112, 147],
102     [255, 20, 147],
103     [255, 105, 180],
104     [255, 182, 193],
105     [255, 192, 203],
106     [250, 235, 215],
107     [245, 245, 220],
108     [255, 228, 196],
109     [255, 235, 205],
110     [245, 222, 179],
111     [255, 248, 220],
112     [255, 250, 205],
113     [250, 250, 210],
114     [255, 255, 224],
115     [139, 69, 19],
116     [160, 82, 45],
117     [210, 105, 30],
118     [205, 133, 63],
119     [244, 164, 96],
120     [222, 184, 135],
121     [210, 180, 140],
122     [188, 143, 143],
123     [255, 228, 181],
124     [255, 222, 173],
125     [255, 218, 185],
126     [255, 228, 225],
127     [255, 240, 245],
128     [250, 240, 230],
129     [253, 245, 230],
130     [255, 239, 213],
131     [255, 245, 238],
132     [245, 255, 250],
133     [112, 128, 144],
134     [119, 136, 153],
135     [176, 196, 222],
136     [230, 230, 250],
137     [255, 250, 240],
138     [240, 248, 255],
139     [248, 248, 255],
140     [240, 255, 240],
141     [255, 255, 240],
142     [240, 255, 255],
143     [255, 250, 250],
144     [192, 192, 192],
145     [220, 220, 220],
146     [245, 245, 245],
147 ]
148
149 color_names = [
150     "white",
151     "red",
152     "green",
153     "blue",
154     "yellow",
155     "black",
156     "maroon",
157     "dark_red",
158     "brown",
159     "firebrick",
160     "crimson",
161     "tomato",
162     "coral",
163     "indian_red",
164     "light_coral",
165     "dark_salmon",
166     "salmon",
167     "light_salmon",
168     "orange_red",
169     "dark_orange",
170     "orange",
171     "gold",
172     "dark_golden_rod",
173     "golden_rod",
174     "pale_golden_rod",
175     "dark_khaki",
176     "khaki",
177     "olive",
178     "yellow_green",
179     "dark_olive_green",
180     "olive_drab",
181     "lawn_green",
182     "chartreuse",
183     "green_yellow",
184     "dark_green",
185     "forest_green",
186     "lime",
187     "lime_green",
188     "light_green",
189     "pale_green",
190     "dark_sea_green",
191     "medium_spring_green",
192     "spring_green",
193     "sea_green",
194     "medium_aqua_marine",
195     "medium_sea_green",
196     "light_sea_green",
197     "dark_slate_gray",
198     "teal",
199     "dark_cyan",
200     "aqua",
201     "cyan",
202     "light_cyan",
203     "dark_turquoise",
204     "turquoise",
205     "medium_turquoise",
206     "pale_turquoise",
207     "aqua_marine",
208     "powder_blue",
209     "cadet_blue",
210     "steel_blue",
211     "corn_flower_blue",
212     "deep_sky_blue",
213     "dodger_blue",
214     "light_blue",
215     "sky_blue",
216     "light_sky_blue",
217     "midnight_blue",
218     "navy",
219     "dark_blue",
220     "medium_blue",
221     "royal_blue",
222     "blue_violet",
223     "indigo",
224     "dark_slate_blue",
225     "slate_blue",
226     "medium_slate_blue",
227     "medium_purple",
228     "dark_magenta",
229     "dark_violet",
230     "dark_orchid",
231     "medium_orchid",
232     "purple",
233     "thistle",
234     "plum",
235     "violet",
236     "magenta",
237     "orchid",
238     "medium_violet_red",
239     "pale_violet_red",
240     "deep_pink",
241     "hot_pink",
242     "light_pink",
243     "pink",
244     "antique_white",
245     "beige",
246     "bisque",
247     "blanched_almond",
248     "wheat",
249     "corn_silk",
250     "lemon_chiffon",
251     "light_golden_rod_yellow",
252     "light_yellow",
253     "saddle_brown",
254     "sienna",
255     "chocolate",
256     "peru",
257     "sandy_brown",
258     "burly_wood",
259     "tan",
260     "rosy_brown",
261     "moccasin",
262     "navajo_white",
263     "peach_puff",
264     "misty_rose",
265     "lavender_blush",
266     "linen",
267     "old_lace",
268     "papaya_whip",
269     "sea_shell",
270     "mint_cream",
271     "slate_gray",
272     "light_slate_gray",
273     "light_steel_blue",
274     "lavender",
275     "floral_white",
276     "alice_blue",
277     "ghost_white",
278     "honeydew",
279     "ivory",
280     "azure",
281     "snow",
282     "silver",
283     "gainsboro",
284     "white_smoke",
285 ]
286
287 color_id = dict([(n, k) for k, n in enumerate(color_names)])
288 color_tokens = dict([(n, c) for n, c in zip(color_names, colors)])
289
290 ######################################################################
291
292
293 def all_properties(height, width, nb_squares, square_i, square_j, square_c):
294     s = []
295
296     for r, c_r in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
297         s += [f"there is {c_r}"]
298
299         if square_i[r] >= height - height // 3:
300             s += [f"{c_r} bottom"]
301         if square_i[r] < height // 3:
302             s += [f"{c_r} top"]
303         if square_j[r] >= width - width // 3:
304             s += [f"{c_r} right"]
305         if square_j[r] < width // 3:
306             s += [f"{c_r} left"]
307
308         for t, c_t in [(k, color_names[square_c[k]]) for k in range(nb_squares)]:
309             if square_i[r] > square_i[t]:
310                 s += [f"{c_r} below {c_t}"]
311             if square_i[r] < square_i[t]:
312                 s += [f"{c_r} above {c_t}"]
313             if square_j[r] > square_j[t]:
314                 s += [f"{c_r} right of {c_t}"]
315             if square_j[r] < square_j[t]:
316                 s += [f"{c_r} left of {c_t}"]
317
318     return s
319
320
321 ######################################################################
322
323 # Generates sequences
324
325
326 def generate(
327     nb,
328     height,
329     width,
330     max_nb_squares=5,
331     max_nb_properties=10,
332     nb_colors=5,
333     pruner=None,
334 ):
335
336     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
337
338     descr = []
339
340     for n in range(nb):
341
342         nb_squares = torch.randint(max_nb_squares, (1,)) + 1
343         square_position = torch.randperm(height * width)[:nb_squares]
344
345         # color 0 is white and reserved for the background
346         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
347         square_i = square_position.div(width, rounding_mode="floor")
348         square_j = square_position % width
349
350         img = [0] * height * width
351         for k in range(nb_squares):
352             img[square_position[k]] = square_c[k]
353
354         # generates all the true properties
355
356         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
357
358         if pruner is not None:
359             s = list(filter(pruner, s))
360
361         # picks at most max_nb_properties at random
362
363         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
364         s = (
365             " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
366             + " <img> "
367             + " ".join([f"{color_names[n]}" for n in img])
368         )
369
370         descr += [s]
371
372     return descr
373
374
375 ######################################################################
376
377 # Extracts the image after <img> in descr as a 1x3xHxW tensor
378
379
380 def descr2img(descr, n, height, width):
381
382     if type(descr) == list:
383         return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
384
385     if type(n) == list:
386         return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
387             0
388         )
389
390     def token2color(t):
391         try:
392             return color_tokens[t]
393         except KeyError:
394             return [128, 128, 128]
395
396     d = descr.split("<img>")
397     d = d[n + 1] if len(d) > n + 1 else ""
398     d = d.strip().split(" ")[: height * width]
399     d = d + ["<unk>"] * (height * width - len(d))
400     d = [token2color(t) for t in d]
401     img = torch.tensor(d).permute(1, 0)
402     img = img.reshape(1, 3, height, width)
403
404     return img
405
406
407 ######################################################################
408
409 # Returns all the properties of the image after <img> in descr
410
411
412 def descr2properties(descr, height, width):
413
414     if type(descr) == list:
415         return [descr2properties(d, height, width) for d in descr]
416
417     d = descr.split("<img>")
418     d = d[-1] if len(d) > 1 else ""
419     d = d.strip().split(" ")[: height * width]
420     if len(d) != height * width:
421         return []
422
423     seen = {}
424     for k, x in enumerate(d):
425         if x != color_names[0]:
426             if x in color_tokens:
427                 if x in seen:
428                     return []
429             else:
430                 return []
431             seen[x] = (color_id[x], k // width, k % width)
432
433     square_infos = tuple(zip(*seen.values()))
434
435     if square_infos:
436         square_c = torch.tensor(square_infos[0])
437         square_i = torch.tensor(square_infos[1])
438         square_j = torch.tensor(square_infos[2])
439     else:
440         square_c = torch.tensor([])
441         square_i = torch.tensor([])
442         square_j = torch.tensor([])
443
444     s = all_properties(height, width, len(seen), square_i, square_j, square_c)
445
446     return s
447
448
449 ######################################################################
450
451 # Returns a triplet composed of (1) the total number of properties
452 # before <img> in descr, (2) the total number of properties the image
453 # after <img> verifies, and (3) the number of properties in (1) not in
454 # (2)
455
456
457 def nb_properties(descr, height, width, pruner=None):
458
459     if type(descr) == list:
460         return [nb_properties(d, height, width, pruner) for d in descr]
461
462     d = descr.split("<img>", 1)
463     if len(d) == 0:
464         return 0
465     d = d[0].strip().split("<sep>")
466     d = [x.strip() for x in d]
467
468     all_properties = set(descr2properties(descr, height, width))
469
470     if pruner is None:
471         requested_properties = set(d)
472     else:
473         requested_properties = set(filter(pruner, d))
474
475     missing_properties = requested_properties - all_properties
476
477     return (len(requested_properties), len(all_properties), len(missing_properties))
478
479
480 ######################################################################
481
482 if __name__ == "__main__":
483     for n in range(16):
484         descr = generate(nb=1, height=12, width=16)
485
486         print(nb_properties(descr, height=12, width=16))
487
488         with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
489             for d in descr:
490                 f.write(f"{d}\n\n")
491
492         img = descr2img(descr, n=0, height=12, width=16)
493         if img.size(0) == 1:
494             img = F.pad(img, (1, 1, 1, 1), value=64)
495
496         torchvision.utils.save_image(
497             img / 255.0,
498             f"picoclvr_example_{n:02d}.png",
499             padding=1,
500             nrow=4,
501             pad_value=0.8,
502         )
503
504     import time
505
506     start_time = time.perf_counter()
507     descr = generate(nb=1000, height=12, width=16)
508     end_time = time.perf_counter()
509     print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
510
511 ######################################################################