X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grid.py;h=f9f15578fb5a124382cc50943be3a1139aba8048;hb=9112db2ed7d8c262c4ef8298cf6637515675f967;hp=268f4eed9ba64a57899a575828ba98cde55e1e13;hpb=4395f9a90218819997c706de9505cda1c86ad507;p=mygptrnn.git diff --git a/grid.py b/grid.py index 268f4ee..f9f1557 100755 --- a/grid.py +++ b/grid.py @@ -9,10 +9,6 @@ import math import torch, torchvision import torch.nn.functional as F -name_shapes = ["A", "B", "C", "D", "E", "F"] - -name_colors = ["red", "yellow", "blue", "green", "white", "purple"] - ###################################################################### @@ -23,20 +19,160 @@ class GridFactory: max_nb_items=4, max_nb_transformations=3, nb_questions=4, + nb_shapes=6, + nb_colors=6, ): assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)] + self.name_colors = [ + "red", + "yellow", + "blue", + "green", + "white", + "black", + "maroon", + "dark_red", + "brown", + "firebrick", + "crimson", + "tomato", + "coral", + "indian_red", + "light_coral", + "dark_salmon", + "salmon", + "light_salmon", + "orange_red", + "dark_orange", + "orange", + "gold", + "dark_golden_rod", + "golden_rod", + "pale_golden_rod", + "dark_khaki", + "khaki", + "olive", + "yellow_green", + "dark_olive_green", + "olive_drab", + "lawn_green", + "chartreuse", + "green_yellow", + "dark_green", + "forest_green", + "lime", + "lime_green", + "light_green", + "pale_green", + "dark_sea_green", + "medium_spring_green", + "spring_green", + "sea_green", + "medium_aqua_marine", + "medium_sea_green", + "light_sea_green", + "dark_slate_gray", + "teal", + "dark_cyan", + "aqua", + "cyan", + "light_cyan", + "dark_turquoise", + "turquoise", + "medium_turquoise", + "pale_turquoise", + "aqua_marine", + "powder_blue", + "cadet_blue", + "steel_blue", + "corn_flower_blue", + "deep_sky_blue", + "dodger_blue", + "light_blue", + "sky_blue", + "light_sky_blue", + "midnight_blue", + "navy", + "dark_blue", + "medium_blue", + "royal_blue", + "blue_violet", + "indigo", + "dark_slate_blue", + "slate_blue", + "medium_slate_blue", + "medium_purple", + "dark_magenta", + "dark_violet", + "dark_orchid", + "medium_orchid", + "purple", + "thistle", + "plum", + "violet", + "magenta", + "orchid", + "medium_violet_red", + "pale_violet_red", + "deep_pink", + "hot_pink", + "light_pink", + "pink", + "antique_white", + "beige", + "bisque", + "blanched_almond", + "wheat", + "corn_silk", + "lemon_chiffon", + "light_golden_rod_yellow", + "light_yellow", + "saddle_brown", + "sienna", + "chocolate", + "peru", + "sandy_brown", + "burly_wood", + "tan", + "rosy_brown", + "moccasin", + "navajo_white", + "peach_puff", + "misty_rose", + "lavender_blush", + "linen", + "old_lace", + "papaya_whip", + "sea_shell", + "mint_cream", + "slate_gray", + "light_slate_gray", + "light_steel_blue", + "lavender", + "floral_white", + "alice_blue", + "ghost_white", + "honeydew", + "ivory", + "azure", + "snow", + "silver", + "gainsboro", + "white_smoke", + ][:nb_colors] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 col = torch.full((self.size * self.size,), -1) shp = torch.full((self.size * self.size,), -1) - a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] - col[:nb_items] = a % len(name_colors) - shp[:nb_items] = a // len(name_colors) + a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] + col[:nb_items] = a % len(self.name_colors) + shp[:nb_items] = a // len(self.name_colors) i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] @@ -76,12 +212,15 @@ class GridFactory: # for i in range(self.size): # for j in range(self.size): # if col[i,j] >= 0: - # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}") for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + print( + f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}", + end="", + ) elif j == 0: print(" +", end="") else: @@ -103,7 +242,7 @@ class GridFactory: for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] return properties @@ -116,7 +255,9 @@ class GridFactory: for i1 in range(self.size): for j1 in range(self.size): if col[i1, j1] >= 0: - n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + n1 = ( + f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}" + ) properties += [f"there is a {n1}"] if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] @@ -129,7 +270,7 @@ class GridFactory: for i2 in range(self.size): for j2 in range(self.size): if col[i2, j2] >= 0: - n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}" if i1 > i2: properties += [f"a {n1} is below a {n2}"] if i1 < i2: