import torch, torchvision
colors = [
- [ 255, 255, 255 ],
- [ 255, 0, 0 ],
- [ 0, 128, 0 ],
- [ 0, 0, 255 ],
- [ 255, 255, 0 ],
- [ 0, 0, 0 ],
- [ 128, 0, 0 ],
- [ 139, 0, 0 ],
- [ 165, 42, 42 ],
- [ 178, 34, 34 ],
- [ 220, 20, 60 ],
- [ 255, 99, 71 ],
- [ 255, 127, 80 ],
- [ 205, 92, 92 ],
- [ 240, 128, 128 ],
- [ 233, 150, 122 ],
- [ 250, 128, 114 ],
- [ 255, 160, 122 ],
- [ 255, 69, 0 ],
- [ 255, 140, 0 ],
- [ 255, 165, 0 ],
- [ 255, 215, 0 ],
- [ 184, 134, 11 ],
- [ 218, 165, 32 ],
- [ 238, 232, 170 ],
- [ 189, 183, 107 ],
- [ 240, 230, 140 ],
- [ 128, 128, 0 ],
- [ 154, 205, 50 ],
- [ 85, 107, 47 ],
- [ 107, 142, 35 ],
- [ 124, 252, 0 ],
- [ 127, 255, 0 ],
- [ 173, 255, 47 ],
- [ 0, 100, 0 ],
- [ 34, 139, 34 ],
- [ 0, 255, 0 ],
- [ 50, 205, 50 ],
- [ 144, 238, 144 ],
- [ 152, 251, 152 ],
- [ 143, 188, 143 ],
- [ 0, 250, 154 ],
- [ 0, 255, 127 ],
- [ 46, 139, 87 ],
- [ 102, 205, 170 ],
- [ 60, 179, 113 ],
- [ 32, 178, 170 ],
- [ 47, 79, 79 ],
- [ 0, 128, 128 ],
- [ 0, 139, 139 ],
- [ 0, 255, 255 ],
- [ 0, 255, 255 ],
- [ 224, 255, 255 ],
- [ 0, 206, 209 ],
- [ 64, 224, 208 ],
- [ 72, 209, 204 ],
- [ 175, 238, 238 ],
- [ 127, 255, 212 ],
- [ 176, 224, 230 ],
- [ 95, 158, 160 ],
- [ 70, 130, 180 ],
- [ 100, 149, 237 ],
- [ 0, 191, 255 ],
- [ 30, 144, 255 ],
- [ 173, 216, 230 ],
- [ 135, 206, 235 ],
- [ 135, 206, 250 ],
- [ 25, 25, 112 ],
- [ 0, 0, 128 ],
- [ 0, 0, 139 ],
- [ 0, 0, 205 ],
- [ 65, 105, 225 ],
- [ 138, 43, 226 ],
- [ 75, 0, 130 ],
- [ 72, 61, 139 ],
- [ 106, 90, 205 ],
- [ 123, 104, 238 ],
- [ 147, 112, 219 ],
- [ 139, 0, 139 ],
- [ 148, 0, 211 ],
- [ 153, 50, 204 ],
- [ 186, 85, 211 ],
- [ 128, 0, 128 ],
- [ 216, 191, 216 ],
- [ 221, 160, 221 ],
- [ 238, 130, 238 ],
- [ 255, 0, 255 ],
- [ 218, 112, 214 ],
- [ 199, 21, 133 ],
- [ 219, 112, 147 ],
- [ 255, 20, 147 ],
- [ 255, 105, 180 ],
- [ 255, 182, 193 ],
- [ 255, 192, 203 ],
- [ 250, 235, 215 ],
- [ 245, 245, 220 ],
- [ 255, 228, 196 ],
- [ 255, 235, 205 ],
- [ 245, 222, 179 ],
- [ 255, 248, 220 ],
- [ 255, 250, 205 ],
- [ 250, 250, 210 ],
- [ 255, 255, 224 ],
- [ 139, 69, 19 ],
- [ 160, 82, 45 ],
- [ 210, 105, 30 ],
- [ 205, 133, 63 ],
- [ 244, 164, 96 ],
- [ 222, 184, 135 ],
- [ 210, 180, 140 ],
- [ 188, 143, 143 ],
- [ 255, 228, 181 ],
- [ 255, 222, 173 ],
- [ 255, 218, 185 ],
- [ 255, 228, 225 ],
- [ 255, 240, 245 ],
- [ 250, 240, 230 ],
- [ 253, 245, 230 ],
- [ 255, 239, 213 ],
- [ 255, 245, 238 ],
- [ 245, 255, 250 ],
- [ 112, 128, 144 ],
- [ 119, 136, 153 ],
- [ 176, 196, 222 ],
- [ 230, 230, 250 ],
- [ 255, 250, 240 ],
- [ 240, 248, 255 ],
- [ 248, 248, 255 ],
- [ 240, 255, 240 ],
- [ 255, 255, 240 ],
- [ 240, 255, 255 ],
- [ 255, 250, 250 ],
- [ 192, 192, 192 ],
- [ 220, 220, 220 ],
- [ 245, 245, 245 ],
+ [ 255, 255, 255 ], [ 255, 0, 0 ], [ 0, 128, 0 ], [ 0, 0, 255 ], [ 255, 255, 0 ],
+ [ 0, 0, 0 ], [ 128, 0, 0 ], [ 139, 0, 0 ], [ 165, 42, 42 ], [ 178, 34, 34 ],
+ [ 220, 20, 60 ], [ 255, 99, 71 ], [ 255, 127, 80 ], [ 205, 92, 92 ], [ 240, 128, 128 ],
+ [ 233, 150, 122 ], [ 250, 128, 114 ], [ 255, 160, 122 ], [ 255, 69, 0 ], [ 255, 140, 0 ],
+ [ 255, 165, 0 ], [ 255, 215, 0 ], [ 184, 134, 11 ], [ 218, 165, 32 ], [ 238, 232, 170 ],
+ [ 189, 183, 107 ], [ 240, 230, 140 ], [ 128, 128, 0 ], [ 154, 205, 50 ], [ 85, 107, 47 ],
+ [ 107, 142, 35 ], [ 124, 252, 0 ], [ 127, 255, 0 ], [ 173, 255, 47 ], [ 0, 100, 0 ],
+ [ 34, 139, 34 ], [ 0, 255, 0 ], [ 50, 205, 50 ], [ 144, 238, 144 ], [ 152, 251, 152 ],
+ [ 143, 188, 143 ], [ 0, 250, 154 ], [ 0, 255, 127 ], [ 46, 139, 87 ], [ 102, 205, 170 ],
+ [ 60, 179, 113 ], [ 32, 178, 170 ], [ 47, 79, 79 ], [ 0, 128, 128 ], [ 0, 139, 139 ],
+ [ 0, 255, 255 ], [ 0, 255, 255 ], [ 224, 255, 255 ], [ 0, 206, 209 ], [ 64, 224, 208 ],
+ [ 72, 209, 204 ], [ 175, 238, 238 ], [ 127, 255, 212 ], [ 176, 224, 230 ], [ 95, 158, 160 ],
+ [ 70, 130, 180 ], [ 100, 149, 237 ], [ 0, 191, 255 ], [ 30, 144, 255 ], [ 173, 216, 230 ],
+ [ 135, 206, 235 ], [ 135, 206, 250 ], [ 25, 25, 112 ], [ 0, 0, 128 ], [ 0, 0, 139 ],
+ [ 0, 0, 205 ], [ 65, 105, 225 ], [ 138, 43, 226 ], [ 75, 0, 130 ], [ 72, 61, 139 ],
+ [ 106, 90, 205 ], [ 123, 104, 238 ], [ 147, 112, 219 ], [ 139, 0, 139 ], [ 148, 0, 211 ],
+ [ 153, 50, 204 ], [ 186, 85, 211 ], [ 128, 0, 128 ], [ 216, 191, 216 ], [ 221, 160, 221 ],
+ [ 238, 130, 238 ], [ 255, 0, 255 ], [ 218, 112, 214 ], [ 199, 21, 133 ], [ 219, 112, 147 ],
+ [ 255, 20, 147 ], [ 255, 105, 180 ], [ 255, 182, 193 ], [ 255, 192, 203 ], [ 250, 235, 215 ],
+ [ 245, 245, 220 ], [ 255, 228, 196 ], [ 255, 235, 205 ], [ 245, 222, 179 ], [ 255, 248, 220 ],
+ [ 255, 250, 205 ], [ 250, 250, 210 ], [ 255, 255, 224 ], [ 139, 69, 19 ], [ 160, 82, 45 ],
+ [ 210, 105, 30 ], [ 205, 133, 63 ], [ 244, 164, 96 ], [ 222, 184, 135 ], [ 210, 180, 140 ],
+ [ 188, 143, 143 ], [ 255, 228, 181 ], [ 255, 222, 173 ], [ 255, 218, 185 ], [ 255, 228, 225 ],
+ [ 255, 240, 245 ], [ 250, 240, 230 ], [ 253, 245, 230 ], [ 255, 239, 213 ], [ 255, 245, 238 ],
+ [ 245, 255, 250 ], [ 112, 128, 144 ], [ 119, 136, 153 ], [ 176, 196, 222 ], [ 230, 230, 250 ],
+ [ 255, 250, 240 ], [ 240, 248, 255 ], [ 248, 248, 255 ], [ 240, 255, 240 ], [ 255, 255, 240 ],
+ [ 240, 255, 255 ], [ 255, 250, 250 ], [ 192, 192, 192 ], [ 220, 220, 220 ], [ 245, 245, 245 ],
]
color_names = [
- 'white',
- 'red',
- 'green',
- 'blue',
- 'yellow',
- 'black',
- 'maroon',
- 'dark_red',
- 'brown',
- 'firebrick',
- 'crimson',
- 'tomato',
- 'coral',
- 'indian_red',
- 'light_coral',
- 'dark_salmon',
- 'salmon',
- 'light_salmon',
- 'orange_red',
- 'dark_orange',
- 'orange',
- 'gold',
- 'dark_golden_rod',
- 'golden_rod',
- 'pale_golden_rod',
- 'dark_khaki',
- 'khaki',
- 'olive',
- 'yellow_green',
- 'dark_olive_green',
- 'olive_drab',
- 'lawn_green',
- 'chartreuse',
- 'green_yellow',
- 'dark_green',
- 'forest_green',
- 'lime',
- 'lime_green',
- 'light_green',
- 'pale_green',
- 'dark_sea_green',
- 'medium_spring_green',
- 'spring_green',
- 'sea_green',
- 'medium_aqua_marine',
- 'medium_sea_green',
- 'light_sea_green',
- 'dark_slate_gray',
- 'teal',
- 'dark_cyan',
- 'aqua',
- 'cyan',
- 'light_cyan',
- 'dark_turquoise',
- 'turquoise',
- 'medium_turquoise',
- 'pale_turquoise',
- 'aqua_marine',
- 'powder_blue',
- 'cadet_blue',
- 'steel_blue',
- 'corn_flower_blue',
- 'deep_sky_blue',
- 'dodger_blue',
- 'light_blue',
- 'sky_blue',
- 'light_sky_blue',
- 'midnight_blue',
- 'navy',
- 'dark_blue',
- 'medium_blue',
- 'royal_blue',
- 'blue_violet',
- 'indigo',
- 'dark_slate_blue',
- 'slate_blue',
- 'medium_slate_blue',
- 'medium_purple',
- 'dark_magenta',
- 'dark_violet',
- 'dark_orchid',
- 'medium_orchid',
- 'purple',
- 'thistle',
- 'plum',
- 'violet',
- 'magenta',
- 'orchid',
- 'medium_violet_red',
- 'pale_violet_red',
- 'deep_pink',
- 'hot_pink',
- 'light_pink',
- 'pink',
- 'antique_white',
- 'beige',
- 'bisque',
- 'blanched_almond',
- 'wheat',
- 'corn_silk',
- 'lemon_chiffon',
- 'light_golden_rod_yellow',
- 'light_yellow',
- 'saddle_brown',
- 'sienna',
- 'chocolate',
- 'peru',
- 'sandy_brown',
- 'burly_wood',
- 'tan',
- 'rosy_brown',
- 'moccasin',
- 'navajo_white',
- 'peach_puff',
- 'misty_rose',
- 'lavender_blush',
- 'linen',
- 'old_lace',
- 'papaya_whip',
- 'sea_shell',
- 'mint_cream',
- 'slate_gray',
- 'light_slate_gray',
- 'light_steel_blue',
- 'lavender',
- 'floral_white',
- 'alice_blue',
- 'ghost_white',
- 'honeydew',
- 'ivory',
- 'azure',
- 'snow',
- 'silver',
- 'gainsboro',
- 'white_smoke',
+ 'white', 'red', 'green', 'blue', 'yellow',
+ 'black', 'maroon', 'dark_red', 'brown', 'firebrick',
+ 'crimson', 'tomato', 'coral', 'indian_red', 'light_coral',
+ 'dark_salmon', 'salmon', 'light_salmon', 'orange_red', 'dark_orange',
+ 'orange', 'gold', 'dark_golden_rod', 'golden_rod', 'pale_golden_rod',
+ 'dark_khaki', 'khaki', 'olive', 'yellow_green', 'dark_olive_green',
+ 'olive_drab', 'lawn_green', 'chartreuse', 'green_yellow', 'dark_green',
+ 'forest_green', 'lime', 'lime_green', 'light_green', 'pale_green',
+ 'dark_sea_green', 'medium_spring_green', 'spring_green', 'sea_green', 'medium_aqua_marine',
+ 'medium_sea_green', 'light_sea_green', 'dark_slate_gray', 'teal', 'dark_cyan',
+ 'aqua', 'cyan', 'light_cyan', 'dark_turquoise', 'turquoise',
+ 'medium_turquoise', 'pale_turquoise', 'aqua_marine', 'powder_blue', 'cadet_blue',
+ 'steel_blue', 'corn_flower_blue', 'deep_sky_blue', 'dodger_blue', 'light_blue',
+ 'sky_blue', 'light_sky_blue', 'midnight_blue', 'navy', 'dark_blue',
+ 'medium_blue', 'royal_blue', 'blue_violet', 'indigo', 'dark_slate_blue',
+ 'slate_blue', 'medium_slate_blue', 'medium_purple', 'dark_magenta', 'dark_violet',
+ 'dark_orchid', 'medium_orchid', 'purple', 'thistle', 'plum',
+ 'violet', 'magenta', 'orchid', 'medium_violet_red', 'pale_violet_red',
+ 'deep_pink', 'hot_pink', 'light_pink', 'pink', 'antique_white',
+ 'beige', 'bisque', 'blanched_almond', 'wheat', 'corn_silk',
+ 'lemon_chiffon', 'light_golden_rod_yellow', 'light_yellow', 'saddle_brown', 'sienna',
+ 'chocolate', 'peru', 'sandy_brown', 'burly_wood', 'tan',
+ 'rosy_brown', 'moccasin', 'navajo_white', 'peach_puff', 'misty_rose',
+ 'lavender_blush', 'linen', 'old_lace', 'papaya_whip', 'sea_shell',
+ 'mint_cream', 'slate_gray', 'light_slate_gray', 'light_steel_blue', 'lavender',
+ 'floral_white', 'alice_blue', 'ghost_white', 'honeydew', 'ivory',
+ 'azure', 'snow', 'silver', 'gainsboro', 'white_smoke',
]
+color_id = dict( [ (n, k) for k, n in enumerate(color_names) ] )
color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] )
######################################################################
-def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False):
+def all_properties(height, width, nb_squares, square_i, square_j, square_c):
+ s = [ ]
+
+ for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
+ s += [ f'there is {c}' ]
+
+ if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
+ if square_i[r] < height//3: s += [ f'{c} top' ]
+ if square_j[r] >= width - width//3: s += [ f'{c} right' ]
+ if square_j[r] < width//3: s += [ f'{c} left' ]
+
+ for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
+ if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
+ if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
+ if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
+ if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
+
+ return s
+
+######################################################################
+
+def generate(nb, height, width,
+ max_nb_squares = 5, max_nb_properties = 10,
+ many_colors = False):
nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
nb_squares = torch.randint(max_nb_squares, (1,)) + 1
square_position = torch.randperm(height * width)[:nb_squares]
+ # color 0 is white and reserved for the background
square_c = torch.randperm(nb_colors)[:nb_squares] + 1
square_i = square_position.div(width, rounding_mode = 'floor')
square_j = square_position % width
img = [ 0 ] * height * width
for k in range(nb_squares): img[square_position[k]] = square_c[k]
- # generates all the true relations
-
- s = [ ]
+ # generates all the true properties
- for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
- s += [ f'there is {c}' ]
+ s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
- if square_i[r] >= height - height//3: s += [ f'{c} bottom' ]
- if square_i[r] < height//3: s += [ f'{c} top' ]
- if square_j[r] >= width - width//3: s += [ f'{c} right' ]
- if square_j[r] < width//3: s += [ f'{c} left' ]
+ # pick at most max_nb_properties at random
- for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]:
- if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ]
- if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ]
- if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ]
- if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ]
-
- # pick at most max_nb_statements at random
-
- nb_statements = torch.randint(max_nb_statements, (1,)) + 1
- s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] )
+ nb_properties = torch.randint(max_nb_properties, (1,)) + 1
+ s = ' <sep> '.join([ s[k] for k in torch.randperm(len(s))[:nb_properties] ] )
s += ' <img> ' + ' '.join([ f'{color_names[n]}' for n in img ])
descr += [ s ]
######################################################################
-def descr2img(descr, height = 6, width = 8):
+def descr2img(descr, height, width):
+
+ if type(descr) == list:
+ return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
def token2color(t):
try:
except KeyError:
return [ 128, 128, 128 ]
- def img_descr(x):
- u = x.split('<img>', 1)
- return u[1] if len(u) > 1 else ''
-
- img = torch.full((len(descr), 3, height, width), 255)
- d = [ img_descr(x) for x in descr ]
- d = [ u.strip().split(' ')[:height * width] for u in d ]
- d = [ u + [ '<unk>' ] * (height * width - len(u)) for u in d ]
- d = [ [ token2color(t) for t in u ] for u in d ]
- img = torch.tensor(d).permute(0, 2, 1)
- img = img.reshape(img.size(0), 3, height, width)
+ d = descr.split('<img>', 1)
+ d = d[-1] if len(d) > 1 else ''
+ d = d.strip().split(' ')[:height * width]
+ d = d + [ '<unk>' ] * (height * width - len(d))
+ d = [ token2color(t) for t in d ]
+ img = torch.tensor(d).permute(1, 0)
+ img = img.reshape(1, 3, height, width)
return img
######################################################################
+def descr2properties(descr, height, width):
+
+ if type(descr) == list:
+ return [ descr2properties(d, height, width) for d in descr ]
+
+ d = descr.split('<img>', 1)
+ d = d[-1] if len(d) > 1 else ''
+ d = d.strip().split(' ')[:height * width]
+
+ seen = {}
+ if len(d) != height * width: return []
+ for k, x in enumerate(d):
+ if x != color_names[0]:
+ if x in color_tokens:
+ if x in seen: return []
+ else:
+ return []
+ seen[x] = (color_id[x], k // width, k % width)
+
+ square_c = torch.tensor( [ x[0] for x in seen.values() ] )
+ square_i = torch.tensor( [ x[1] for x in seen.values() ] )
+ square_j = torch.tensor( [ x[2] for x in seen.values() ] )
+
+ s = all_properties(height, width, len(seen), square_i, square_j, square_c)
+
+ return s
+
+######################################################################
+
+def nb_missing_properties(descr, height, width):
+ if type(descr) == list:
+ return [ nb_missing_properties(d, height, width) for d in descr ]
+
+ d = descr.split('<img>', 1)
+ if len(d) == 0: return 0
+ d = d[0].strip().split('<sep>')
+ d = [ x.strip() for x in d ]
+
+ requested_properties = set(d)
+ all_properties = set(descr2properties(descr, height, width))
+ missing_properties = requested_properties - all_properties
+
+ return (len(requested_properties), len(all_properties), len(missing_properties))
+
+######################################################################
+
if __name__ == '__main__':
descr = generate(nb = 5)
- for d in descr:
- print(d)
- print()
- img = descr2img(descr)
- print(img.size())
+ #print(descr2properties(descr))
+ print(nb_missing_properties(descr))
+
+ with open('picoclvr_example.txt', 'w') as f:
+ for d in descr:
+ f.write(f'{d}\n\n')
+ img = descr2img(descr)
torchvision.utils.save_image(img / 255.,
'picoclvr_example.png', nrow = 16, pad_value = 0.8)
import time
start_time = time.perf_counter()
- descr = generate(10000)
+ descr = generate(nb = 1000)
end_time = time.perf_counter()
print(f'{len(descr) / (end_time - start_time):.02f} samples per second')