X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=0cd306243fd79ee90728eb16340581aad077914f;hb=e39282eef52a7f5ab6654b999009127569b1b599;hp=94c0f88e71b5685f433c94b3e47252f0bb3bfafa;hpb=943a440a83b98de60bad767a9ad09f63b5088514;p=picoclvr.git diff --git a/picoclvr.py b/picoclvr.py index 94c0f88..0cd3062 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -5,287 +5,150 @@ # Written by Francois Fleuret +import math import torch, torchvision import torch.nn.functional as F -colors = [ - [255, 255, 255], - [255, 0, 0], - [0, 128, 0], - [0, 0, 255], - [255, 255, 0], - [0, 0, 0], - [128, 0, 0], - [139, 0, 0], - [165, 42, 42], - [178, 34, 34], - [220, 20, 60], - [255, 99, 71], - [255, 127, 80], - [205, 92, 92], - [240, 128, 128], - [233, 150, 122], - [250, 128, 114], - [255, 160, 122], - [255, 69, 0], - [255, 140, 0], - [255, 165, 0], - [255, 215, 0], - [184, 134, 11], - [218, 165, 32], - [238, 232, 170], - [189, 183, 107], - [240, 230, 140], - [128, 128, 0], - [154, 205, 50], - [85, 107, 47], - [107, 142, 35], - [124, 252, 0], - [127, 255, 0], - [173, 255, 47], - [0, 100, 0], - [34, 139, 34], - [0, 255, 0], - [50, 205, 50], - [144, 238, 144], - [152, 251, 152], - [143, 188, 143], - [0, 250, 154], - [0, 255, 127], - [46, 139, 87], - [102, 205, 170], - [60, 179, 113], - [32, 178, 170], - [47, 79, 79], - [0, 128, 128], - [0, 139, 139], - [0, 255, 255], - [0, 255, 255], - [224, 255, 255], - [0, 206, 209], - [64, 224, 208], - [72, 209, 204], - [175, 238, 238], - [127, 255, 212], - [176, 224, 230], - [95, 158, 160], - [70, 130, 180], - [100, 149, 237], - [0, 191, 255], - [30, 144, 255], - [173, 216, 230], - [135, 206, 235], - [135, 206, 250], - [25, 25, 112], - [0, 0, 128], - [0, 0, 139], - [0, 0, 205], - [65, 105, 225], - [138, 43, 226], - [75, 0, 130], - [72, 61, 139], - [106, 90, 205], - [123, 104, 238], - [147, 112, 219], - [139, 0, 139], - [148, 0, 211], - [153, 50, 204], - [186, 85, 211], - [128, 0, 128], - [216, 191, 216], - [221, 160, 221], - [238, 130, 238], - [255, 0, 255], - [218, 112, 214], - [199, 21, 133], - [219, 112, 147], - [255, 20, 147], - [255, 105, 180], - [255, 182, 193], - [255, 192, 203], - [250, 235, 215], - [245, 245, 220], - [255, 228, 196], - [255, 235, 205], - [245, 222, 179], - [255, 248, 220], - [255, 250, 205], - [250, 250, 210], - [255, 255, 224], - [139, 69, 19], - [160, 82, 45], - [210, 105, 30], - [205, 133, 63], - [244, 164, 96], - [222, 184, 135], - [210, 180, 140], - [188, 143, 143], - [255, 228, 181], - [255, 222, 173], - [255, 218, 185], - [255, 228, 225], - [255, 240, 245], - [250, 240, 230], - [253, 245, 230], - [255, 239, 213], - [255, 245, 238], - [245, 255, 250], - [112, 128, 144], - [119, 136, 153], - [176, 196, 222], - [230, 230, 250], - [255, 250, 240], - [240, 248, 255], - [248, 248, 255], - [240, 255, 240], - [255, 255, 240], - [240, 255, 255], - [255, 250, 250], - [192, 192, 192], - [220, 220, 220], - [245, 245, 245], -] - -color_names = [ - "white", - "red", - "green", - "blue", - "yellow", - "black", - "maroon", - "dark_red", - "brown", - "firebrick", - "crimson", - "tomato", - "coral", - "indian_red", - "light_coral", - "dark_salmon", - "salmon", - "light_salmon", - "orange_red", - "dark_orange", - "orange", - "gold", - "dark_golden_rod", - "golden_rod", - "pale_golden_rod", - "dark_khaki", - "khaki", - "olive", - "yellow_green", - "dark_olive_green", - "olive_drab", - "lawn_green", - "chartreuse", - "green_yellow", - "dark_green", - "forest_green", - "lime", - "lime_green", - "light_green", - "pale_green", - "dark_sea_green", - "medium_spring_green", - "spring_green", - "sea_green", - "medium_aqua_marine", - "medium_sea_green", - "light_sea_green", - "dark_slate_gray", - "teal", - "dark_cyan", - "aqua", - "cyan", - "light_cyan", - "dark_turquoise", - "turquoise", - "medium_turquoise", - "pale_turquoise", - "aqua_marine", - "powder_blue", - "cadet_blue", - "steel_blue", - "corn_flower_blue", - "deep_sky_blue", - "dodger_blue", - "light_blue", - "sky_blue", - "light_sky_blue", - "midnight_blue", - "navy", - "dark_blue", - "medium_blue", - "royal_blue", - "blue_violet", - "indigo", - "dark_slate_blue", - "slate_blue", - "medium_slate_blue", - "medium_purple", - "dark_magenta", - "dark_violet", - "dark_orchid", - "medium_orchid", - "purple", - "thistle", - "plum", - "violet", - "magenta", - "orchid", - "medium_violet_red", - "pale_violet_red", - "deep_pink", - "hot_pink", - "light_pink", - "pink", - "antique_white", - "beige", - "bisque", - "blanched_almond", - "wheat", - "corn_silk", - "lemon_chiffon", - "light_golden_rod_yellow", - "light_yellow", - "saddle_brown", - "sienna", - "chocolate", - "peru", - "sandy_brown", - "burly_wood", - "tan", - "rosy_brown", - "moccasin", - "navajo_white", - "peach_puff", - "misty_rose", - "lavender_blush", - "linen", - "old_lace", - "papaya_whip", - "sea_shell", - "mint_cream", - "slate_gray", - "light_slate_gray", - "light_steel_blue", - "lavender", - "floral_white", - "alice_blue", - "ghost_white", - "honeydew", - "ivory", - "azure", - "snow", - "silver", - "gainsboro", - "white_smoke", -] - -color_id = dict([(n, k) for k, n in enumerate(color_names)]) -color_tokens = dict([(n, c) for n, c in zip(color_names, colors)]) +color_name2rgb = { + "white": [255, 255, 255], + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "black": [0, 0, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + "light_salmon": [255, 160, 122], + "orange_red": [255, 69, 0], + "dark_orange": [255, 140, 0], + "orange": [255, 165, 0], + "gold": [255, 215, 0], + "dark_golden_rod": [184, 134, 11], + "golden_rod": [218, 165, 32], + "pale_golden_rod": [238, 232, 170], + "dark_khaki": [189, 183, 107], + "khaki": [240, 230, 140], + "olive": [128, 128, 0], + "yellow_green": [154, 205, 50], + "dark_olive_green": [85, 107, 47], + "olive_drab": [107, 142, 35], + "lawn_green": [124, 252, 0], + "chartreuse": [127, 255, 0], + "green_yellow": [173, 255, 47], + "dark_green": [0, 100, 0], + "forest_green": [34, 139, 34], + "lime": [0, 255, 0], + "lime_green": [50, 205, 50], + "light_green": [144, 238, 144], + "pale_green": [152, 251, 152], + "dark_sea_green": [143, 188, 143], + "medium_spring_green": [0, 250, 154], + "spring_green": [0, 255, 127], + "sea_green": [46, 139, 87], + "medium_aqua_marine": [102, 205, 170], + "medium_sea_green": [60, 179, 113], + "light_sea_green": [32, 178, 170], + "dark_slate_gray": [47, 79, 79], + "teal": [0, 128, 128], + "dark_cyan": [0, 139, 139], + "aqua": [0, 255, 255], + "cyan": [0, 255, 255], + "light_cyan": [224, 255, 255], + "dark_turquoise": [0, 206, 209], + "turquoise": [64, 224, 208], + "medium_turquoise": [72, 209, 204], + "pale_turquoise": [175, 238, 238], + "aqua_marine": [127, 255, 212], + "powder_blue": [176, 224, 230], + "cadet_blue": [95, 158, 160], + "steel_blue": [70, 130, 180], + "corn_flower_blue": [100, 149, 237], + "deep_sky_blue": [0, 191, 255], + "dodger_blue": [30, 144, 255], + "light_blue": [173, 216, 230], + "sky_blue": [135, 206, 235], + "light_sky_blue": [135, 206, 250], + "midnight_blue": [25, 25, 112], + "navy": [0, 0, 128], + "dark_blue": [0, 0, 139], + "medium_blue": [0, 0, 205], + "royal_blue": [65, 105, 225], + "blue_violet": [138, 43, 226], + "indigo": [75, 0, 130], + "dark_slate_blue": [72, 61, 139], + "slate_blue": [106, 90, 205], + "medium_slate_blue": [123, 104, 238], + "medium_purple": [147, 112, 219], + "dark_magenta": [139, 0, 139], + "dark_violet": [148, 0, 211], + "dark_orchid": [153, 50, 204], + "medium_orchid": [186, 85, 211], + "purple": [128, 0, 128], + "thistle": [216, 191, 216], + "plum": [221, 160, 221], + "violet": [238, 130, 238], + "magenta": [255, 0, 255], + "orchid": [218, 112, 214], + "medium_violet_red": [199, 21, 133], + "pale_violet_red": [219, 112, 147], + "deep_pink": [255, 20, 147], + "hot_pink": [255, 105, 180], + "light_pink": [255, 182, 193], + "pink": [255, 192, 203], + "antique_white": [250, 235, 215], + "beige": [245, 245, 220], + "bisque": [255, 228, 196], + "blanched_almond": [255, 235, 205], + "wheat": [245, 222, 179], + "corn_silk": [255, 248, 220], + "lemon_chiffon": [255, 250, 205], + "light_golden_rod_yellow": [250, 250, 210], + "light_yellow": [255, 255, 224], + "saddle_brown": [139, 69, 19], + "sienna": [160, 82, 45], + "chocolate": [210, 105, 30], + "peru": [205, 133, 63], + "sandy_brown": [244, 164, 96], + "burly_wood": [222, 184, 135], + "tan": [210, 180, 140], + "rosy_brown": [188, 143, 143], + "moccasin": [255, 228, 181], + "navajo_white": [255, 222, 173], + "peach_puff": [255, 218, 185], + "misty_rose": [255, 228, 225], + "lavender_blush": [255, 240, 245], + "linen": [250, 240, 230], + "old_lace": [253, 245, 230], + "papaya_whip": [255, 239, 213], + "sea_shell": [255, 245, 238], + "mint_cream": [245, 255, 250], + "slate_gray": [112, 128, 144], + "light_slate_gray": [119, 136, 153], + "light_steel_blue": [176, 196, 222], + "lavender": [230, 230, 250], + "floral_white": [255, 250, 240], + "alice_blue": [240, 248, 255], + "ghost_white": [248, 248, 255], + "honeydew": [240, 255, 240], + "ivory": [255, 255, 240], + "azure": [240, 255, 255], + "snow": [255, 250, 250], + "silver": [192, 192, 192], + "gainsboro": [220, 220, 220], + "white_smoke": [245, 245, 245], +} + +color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())]) +color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())]) ###################################################################### @@ -293,7 +156,7 @@ color_tokens = dict([(n, c) for n, c in zip(color_names, colors)]) def all_properties(height, width, nb_squares, square_i, square_j, square_c): s = [] - for r, c_r in [(k, color_names[square_c[k]]) for k in range(nb_squares)]: + for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]: s += [f"there is {c_r}"] if square_i[r] >= height - height // 3: @@ -305,7 +168,9 @@ def all_properties(height, width, nb_squares, square_i, square_j, square_c): if square_j[r] < width // 3: s += [f"{c_r} left"] - for t, c_t in [(k, color_names[square_c[k]]) for k in range(nb_squares)]: + for t, c_t in [ + (k, color_id2name[square_c[k].item()]) for k in range(nb_squares) + ]: if square_i[r] > square_i[t]: s += [f"{c_r} below {c_t}"] if square_i[r] < square_i[t]: @@ -332,14 +197,17 @@ def generate( nb_colors=5, pruner=None, ): - - assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1 + assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 descr = [] for n in range(nb): - - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background @@ -347,7 +215,7 @@ def generate( square_i = square_position.div(width, rounding_mode="floor") square_j = square_position % width - img = [0] * height * width + img = torch.zeros(height * width, dtype=torch.int64) for k in range(nb_squares): img[square_position[k]] = square_c[k] @@ -364,7 +232,7 @@ def generate( s = ( " ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]]) + " " - + " ".join([f"{color_names[n]}" for n in img]) + + " ".join([f"{color_id2name[n.item()]}" for n in img]) ) descr += [s] @@ -377,31 +245,24 @@ def generate( # Extracts the image after in descr as a 1x3xHxW tensor -def descr2img(descr, n, height, width): - - if type(descr) == list: - return torch.cat([descr2img(d, n, height, width) for d in descr], 0) - - if type(n) == list: - return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze( - 0 - ) +def descr2img(descr, height, width): + result = [] def token2color(t): try: - return color_tokens[t] + return color_name2rgb[t] except KeyError: return [128, 128, 128] - d = descr.split("") - d = d[n + 1] if len(d) > n + 1 else "" - d = d.strip().split(" ")[: height * width] - d = d + [""] * (height * width - len(d)) - d = [token2color(t) for t in d] - img = torch.tensor(d).permute(1, 0) - img = img.reshape(1, 3, height, width) + for d in descr: + d = d.split("")[1] + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] + img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width) + result.append(img) - return img + return torch.cat(result, 0) ###################################################################### @@ -410,25 +271,24 @@ def descr2img(descr, n, height, width): def descr2properties(descr, height, width): - if type(descr) == list: return [descr2properties(d, height, width) for d in descr] d = descr.split("") - d = d[-1] if len(d) > 1 else "" - d = d.strip().split(" ")[: height * width] - if len(d) != height * width: + img_tokens = d[-1] if len(d) > 1 else "" + img_tokens = img_tokens.strip().split(" ")[: height * width] + if len(img_tokens) != height * width: return [] seen = {} - for k, x in enumerate(d): - if x != color_names[0]: - if x in color_tokens: + for k, x in enumerate(img_tokens): + if x != color_id2name[0]: + if x in color_name2rgb: if x in seen: return [] else: return [] - seen[x] = (color_id[x], k // width, k % width) + seen[x] = (color_name2id[x], k // width, k % width) square_infos = tuple(zip(*seen.values())) @@ -455,7 +315,6 @@ def descr2properties(descr, height, width): def nb_properties(descr, height, width, pruner=None): - if type(descr) == list: return [nb_properties(d, height, width, pruner) for d in descr] @@ -489,7 +348,7 @@ if __name__ == "__main__": for d in descr: f.write(f"{d}\n\n") - img = descr2img(descr, n=0, height=12, width=16) + img = descr2img(descr, height=12, width=16) if img.size(0) == 1: img = F.pad(img, (1, 1, 1, 1), value=64)