Added default configurations and reformated with black.
[mygpt.git] / picoclvr.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9
10 color_tokens = {
11     "white": [255, 255, 255],
12     "red": [255, 0, 0],
13     "green": [0, 128, 0],
14     "blue": [0, 0, 255],
15     "yellow": [255, 255, 0],
16     "black": [0, 0, 0],
17     "maroon": [128, 0, 0],
18     "dark_red": [139, 0, 0],
19     "brown": [165, 42, 42],
20     "firebrick": [178, 34, 34],
21     "crimson": [220, 20, 60],
22     "tomato": [255, 99, 71],
23     "coral": [255, 127, 80],
24     "indian_red": [205, 92, 92],
25     "light_coral": [240, 128, 128],
26     "dark_salmon": [233, 150, 122],
27     "salmon": [250, 128, 114],
28     "light_salmon": [255, 160, 122],
29     "orange_red": [255, 69, 0],
30     "dark_orange": [255, 140, 0],
31     "orange": [255, 165, 0],
32     "gold": [255, 215, 0],
33     "dark_golden_rod": [184, 134, 11],
34     "golden_rod": [218, 165, 32],
35     "pale_golden_rod": [238, 232, 170],
36     "dark_khaki": [189, 183, 107],
37     "khaki": [240, 230, 140],
38     "olive": [128, 128, 0],
39     "yellow_green": [154, 205, 50],
40     "dark_olive_green": [85, 107, 47],
41     "olive_drab": [107, 142, 35],
42     "lawn_green": [124, 252, 0],
43     "chartreuse": [127, 255, 0],
44     "green_yellow": [173, 255, 47],
45     "dark_green": [0, 100, 0],
46     "forest_green": [34, 139, 34],
47     "lime": [0, 255, 0],
48     "lime_green": [50, 205, 50],
49     "light_green": [144, 238, 144],
50     "pale_green": [152, 251, 152],
51     "dark_sea_green": [143, 188, 143],
52     "medium_spring_green": [0, 250, 154],
53     "spring_green": [0, 255, 127],
54     "sea_green": [46, 139, 87],
55     "medium_aqua_marine": [102, 205, 170],
56     "medium_sea_green": [60, 179, 113],
57     "light_sea_green": [32, 178, 170],
58     "dark_slate_gray": [47, 79, 79],
59     "teal": [0, 128, 128],
60     "dark_cyan": [0, 139, 139],
61     "aqua": [0, 255, 255],
62     "cyan": [0, 255, 255],
63     "light_cyan": [224, 255, 255],
64     "dark_turquoise": [0, 206, 209],
65     "turquoise": [64, 224, 208],
66     "medium_turquoise": [72, 209, 204],
67     "pale_turquoise": [175, 238, 238],
68     "aqua_marine": [127, 255, 212],
69     "powder_blue": [176, 224, 230],
70     "cadet_blue": [95, 158, 160],
71     "steel_blue": [70, 130, 180],
72     "corn_flower_blue": [100, 149, 237],
73     "deep_sky_blue": [0, 191, 255],
74     "dodger_blue": [30, 144, 255],
75     "light_blue": [173, 216, 230],
76     "sky_blue": [135, 206, 235],
77     "light_sky_blue": [135, 206, 250],
78     "midnight_blue": [25, 25, 112],
79     "navy": [0, 0, 128],
80     "dark_blue": [0, 0, 139],
81     "medium_blue": [0, 0, 205],
82     "royal_blue": [65, 105, 225],
83     "blue_violet": [138, 43, 226],
84     "indigo": [75, 0, 130],
85     "dark_slate_blue": [72, 61, 139],
86     "slate_blue": [106, 90, 205],
87     "medium_slate_blue": [123, 104, 238],
88     "medium_purple": [147, 112, 219],
89     "dark_magenta": [139, 0, 139],
90     "dark_violet": [148, 0, 211],
91     "dark_orchid": [153, 50, 204],
92     "medium_orchid": [186, 85, 211],
93     "purple": [128, 0, 128],
94     "thistle": [216, 191, 216],
95     "plum": [221, 160, 221],
96     "violet": [238, 130, 238],
97     "magenta": [255, 0, 255],
98     "orchid": [218, 112, 214],
99     "medium_violet_red": [199, 21, 133],
100     "pale_violet_red": [219, 112, 147],
101     "deep_pink": [255, 20, 147],
102     "hot_pink": [255, 105, 180],
103     "light_pink": [255, 182, 193],
104     "pink": [255, 192, 203],
105     "antique_white": [250, 235, 215],
106     "beige": [245, 245, 220],
107     "bisque": [255, 228, 196],
108     "blanched_almond": [255, 235, 205],
109     "wheat": [245, 222, 179],
110     "corn_silk": [255, 248, 220],
111     "lemon_chiffon": [255, 250, 205],
112     "light_golden_rod_yellow": [250, 250, 210],
113     "light_yellow": [255, 255, 224],
114     "saddle_brown": [139, 69, 19],
115     "sienna": [160, 82, 45],
116     "chocolate": [210, 105, 30],
117     "peru": [205, 133, 63],
118     "sandy_brown": [244, 164, 96],
119     "burly_wood": [222, 184, 135],
120     "tan": [210, 180, 140],
121     "rosy_brown": [188, 143, 143],
122     "moccasin": [255, 228, 181],
123     "navajo_white": [255, 222, 173],
124     "peach_puff": [255, 218, 185],
125     "misty_rose": [255, 228, 225],
126     "lavender_blush": [255, 240, 245],
127     "linen": [250, 240, 230],
128     "old_lace": [253, 245, 230],
129     "papaya_whip": [255, 239, 213],
130     "sea_shell": [255, 245, 238],
131     "mint_cream": [245, 255, 250],
132     "slate_gray": [112, 128, 144],
133     "light_slate_gray": [119, 136, 153],
134     "light_steel_blue": [176, 196, 222],
135     "lavender": [230, 230, 250],
136     "floral_white": [255, 250, 240],
137     "alice_blue": [240, 248, 255],
138     "ghost_white": [248, 248, 255],
139     "honeydew": [240, 255, 240],
140     "ivory": [255, 255, 240],
141     "azure": [240, 255, 255],
142     "snow": [255, 250, 250],
143     "silver": [192, 192, 192],
144     "gainsboro": [220, 220, 220],
145     "white_smoke": [245, 245, 245],
146 }
147
148 color_id = dict([(n, k) for k, n in enumerate(color_tokens.keys())])
149 color_names = dict([(k, n) for k, n in enumerate(color_tokens.keys())])
150
151 ######################################################################
152
153
154 def all_properties(height, width, nb_squares, square_i, square_j, square_c):
155     s = []
156
157     for r, c in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
158         s += [f"there is {c}"]
159
160         if square_i[r] >= height - height // 3:
161             s += [f"{c} bottom"]
162         if square_i[r] < height // 3:
163             s += [f"{c} top"]
164         if square_j[r] >= width - width // 3:
165             s += [f"{c} right"]
166         if square_j[r] < width // 3:
167             s += [f"{c} left"]
168
169         for t, d in [(k, color_names[square_c[k].item()]) for k in range(nb_squares)]:
170             if square_i[r] > square_i[t]:
171                 s += [f"{c} below {d}"]
172             if square_i[r] < square_i[t]:
173                 s += [f"{c} above {d}"]
174             if square_j[r] > square_j[t]:
175                 s += [f"{c} right of {d}"]
176             if square_j[r] < square_j[t]:
177                 s += [f"{c} left of {d}"]
178
179     return s
180
181
182 ######################################################################
183
184
185 def generate(
186     nb,
187     height,
188     width,
189     max_nb_squares=5,
190     max_nb_properties=10,
191     nb_colors=5,
192     pruning_criterion=None,
193 ):
194
195     assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
196
197     descr = []
198
199     for n in range(nb):
200
201         nb_squares = torch.randint(max_nb_squares, (1,)) + 1
202         square_position = torch.randperm(height * width)[:nb_squares]
203         # color 0 is white and reserved for the background
204         square_c = torch.randperm(nb_colors)[:nb_squares] + 1
205         square_i = square_position.div(width, rounding_mode="floor")
206         square_j = square_position % width
207
208         img = torch.zeros(height * width, dtype=torch.int64)
209         for k in range(nb_squares):
210             img[square_position[k]] = square_c[k]
211
212         # generates all the true properties
213
214         s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
215
216         if pruning_criterion is not None:
217             s = list(filter(pruning_criterion, s))
218
219         # pick at most max_nb_properties at random
220
221         nb_properties = torch.randint(max_nb_properties, (1,)) + 1
222         s = " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
223         s += " <img> " + " ".join([f"{color_names[n.item()]}" for n in img])
224
225         descr += [s]
226
227     return descr
228
229
230 ######################################################################
231
232
233 def descr2img(descr, height, width):
234
235     if type(descr) == list:
236         return torch.cat([descr2img(d, height, width) for d in descr], 0)
237
238     def token2color(t):
239         try:
240             return color_tokens[t]
241         except KeyError:
242             return [128, 128, 128]
243
244     d = descr.split("<img>", 1)
245     d = d[-1] if len(d) > 1 else ""
246     d = d.strip().split(" ")[: height * width]
247     d = d + ["<unk>"] * (height * width - len(d))
248     d = [token2color(t) for t in d]
249     img = torch.tensor(d).permute(1, 0)
250     img = img.reshape(1, 3, height, width)
251
252     return img
253
254
255 ######################################################################
256
257
258 def descr2properties(descr, height, width):
259
260     if type(descr) == list:
261         return [descr2properties(d, height, width) for d in descr]
262
263     d = descr.split("<img>", 1)
264     d = d[-1] if len(d) > 1 else ""
265     d = d.strip().split(" ")[: height * width]
266
267     seen = {}
268     if len(d) != height * width:
269         return []
270
271     for k, x in enumerate(d):
272         if x != color_names[0]:
273             if x in color_tokens:
274                 if x in seen:
275                     return []
276             else:
277                 return []
278             seen[x] = (color_id[x], k // width, k % width)
279
280     square_infos = tuple(zip(*seen.values()))
281     if square_infos:
282         square_c = torch.tensor(square_infos[0])
283         square_i = torch.tensor(square_infos[1])
284         square_j = torch.tensor(square_infos[2])
285     else:
286         square_c = torch.tensor([])
287         square_i = torch.tensor([])
288         square_j = torch.tensor([])
289
290     s = all_properties(height, width, len(seen), square_i, square_j, square_c)
291
292     return s
293
294
295 ######################################################################
296
297
298 def nb_properties(descr, height, width):
299     if type(descr) == list:
300         return [nb_properties(d, height, width) for d in descr]
301
302     d = descr.split("<img>", 1)
303     if len(d) == 0:
304         return 0
305     d = d[0].strip().split("<sep>")
306     d = [x.strip() for x in d]
307
308     requested_properties = set(d)
309     all_properties = set(descr2properties(descr, height, width))
310     missing_properties = requested_properties - all_properties
311
312     return (len(requested_properties), len(all_properties), len(missing_properties))
313
314
315 ######################################################################
316
317 if __name__ == "__main__":
318     descr = generate(
319         nb=5,
320         height=12,
321         width=16,
322         pruning_criterion=lambda s: not (
323             "green" in s and ("right" in s or "left" in s)
324         ),
325     )
326
327     print(descr2properties(descr, height=12, width=16))
328     print(nb_properties(descr, height=12, width=16))
329
330     with open("picoclvr_example.txt", "w") as f:
331         for d in descr:
332             f.write(f"{d}\n\n")
333
334     img = descr2img(descr, height=12, width=16)
335     torchvision.utils.save_image(
336         img / 255.0, "picoclvr_example.png", nrow=16, pad_value=0.8
337     )
338
339     import time
340
341     start_time = time.perf_counter()
342     descr = generate(nb=1000, height=12, width=16)
343     end_time = time.perf_counter()
344     print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
345
346 ######################################################################