import torch, torchvision
import torch.nn.functional as F
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
######################################################################
max_nb_items=4,
max_nb_transformations=3,
nb_questions=4,
+ nb_shapes=6,
+ nb_colors=6,
):
assert size % 2 == 0
self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
+ self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)]
+ self.name_colors = [
+ "red",
+ "yellow",
+ "blue",
+ "green",
+ "white",
+ "black",
+ "maroon",
+ "dark_red",
+ "brown",
+ "firebrick",
+ "crimson",
+ "tomato",
+ "coral",
+ "indian_red",
+ "light_coral",
+ "dark_salmon",
+ "salmon",
+ "light_salmon",
+ "orange_red",
+ "dark_orange",
+ "orange",
+ "gold",
+ "dark_golden_rod",
+ "golden_rod",
+ "pale_golden_rod",
+ "dark_khaki",
+ "khaki",
+ "olive",
+ "yellow_green",
+ "dark_olive_green",
+ "olive_drab",
+ "lawn_green",
+ "chartreuse",
+ "green_yellow",
+ "dark_green",
+ "forest_green",
+ "lime",
+ "lime_green",
+ "light_green",
+ "pale_green",
+ "dark_sea_green",
+ "medium_spring_green",
+ "spring_green",
+ "sea_green",
+ "medium_aqua_marine",
+ "medium_sea_green",
+ "light_sea_green",
+ "dark_slate_gray",
+ "teal",
+ "dark_cyan",
+ "aqua",
+ "cyan",
+ "light_cyan",
+ "dark_turquoise",
+ "turquoise",
+ "medium_turquoise",
+ "pale_turquoise",
+ "aqua_marine",
+ "powder_blue",
+ "cadet_blue",
+ "steel_blue",
+ "corn_flower_blue",
+ "deep_sky_blue",
+ "dodger_blue",
+ "light_blue",
+ "sky_blue",
+ "light_sky_blue",
+ "midnight_blue",
+ "navy",
+ "dark_blue",
+ "medium_blue",
+ "royal_blue",
+ "blue_violet",
+ "indigo",
+ "dark_slate_blue",
+ "slate_blue",
+ "medium_slate_blue",
+ "medium_purple",
+ "dark_magenta",
+ "dark_violet",
+ "dark_orchid",
+ "medium_orchid",
+ "purple",
+ "thistle",
+ "plum",
+ "violet",
+ "magenta",
+ "orchid",
+ "medium_violet_red",
+ "pale_violet_red",
+ "deep_pink",
+ "hot_pink",
+ "light_pink",
+ "pink",
+ "antique_white",
+ "beige",
+ "bisque",
+ "blanched_almond",
+ "wheat",
+ "corn_silk",
+ "lemon_chiffon",
+ "light_golden_rod_yellow",
+ "light_yellow",
+ "saddle_brown",
+ "sienna",
+ "chocolate",
+ "peru",
+ "sandy_brown",
+ "burly_wood",
+ "tan",
+ "rosy_brown",
+ "moccasin",
+ "navajo_white",
+ "peach_puff",
+ "misty_rose",
+ "lavender_blush",
+ "linen",
+ "old_lace",
+ "papaya_whip",
+ "sea_shell",
+ "mint_cream",
+ "slate_gray",
+ "light_slate_gray",
+ "light_steel_blue",
+ "lavender",
+ "floral_white",
+ "alice_blue",
+ "ghost_white",
+ "honeydew",
+ "ivory",
+ "azure",
+ "snow",
+ "silver",
+ "gainsboro",
+ "white_smoke",
+ ][:nb_colors]
def generate_scene(self):
nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
col = torch.full((self.size * self.size,), -1)
shp = torch.full((self.size * self.size,), -1)
- a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
- col[:nb_items] = a % len(name_colors)
- shp[:nb_items] = a // len(name_colors)
+ a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+ col[:nb_items] = a % len(self.name_colors)
+ shp[:nb_items] = a // len(self.name_colors)
i = torch.randperm(self.size * self.size)
col = col[i]
shp = shp[i]
# for i in range(self.size):
# for j in range(self.size):
# if col[i,j] >= 0:
- # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+ # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+ print(
+ f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+ end="",
+ )
elif j == 0:
print(" +", end="")
else:
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+ n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
properties += [f"a {n} at {i} {j}"]
return properties
for i1 in range(self.size):
for j1 in range(self.size):
if col[i1, j1] >= 0:
- n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+ n1 = (
+ f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+ )
properties += [f"there is a {n1}"]
if i1 < self.size // 2:
properties += [f"a {n1} is in the top half"]
for i2 in range(self.size):
for j2 in range(self.size):
if col[i2, j2] >= 0:
- n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+ n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
if i1 > i2:
properties += [f"a {n1} is below a {n2}"]
if i1 < i2: