import torch, torchvision
import torch.nn.functional as F
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
######################################################################
max_nb_items=4,
max_nb_transformations=3,
nb_questions=4,
+ nb_shapes=6,
+ nb_colors=6,
):
assert size % 2 == 0
self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
+ self.name_shapes = ["A", "B", "C", "D", "E", "F"]
+ self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
def generate_scene(self):
nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
col = torch.full((self.size * self.size,), -1)
shp = torch.full((self.size * self.size,), -1)
- a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
- col[:nb_items] = a % len(name_colors)
- shp[:nb_items] = a // len(name_colors)
+ a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+ col[:nb_items] = a % len(self.name_colors)
+ shp[:nb_items] = a // len(self.name_colors)
i = torch.randperm(self.size * self.size)
col = col[i]
shp = shp[i]
# for i in range(self.size):
# for j in range(self.size):
# if col[i,j] >= 0:
- # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+ # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+ print(
+ f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+ end="",
+ )
elif j == 0:
print(" +", end="")
else:
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+ n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
properties += [f"a {n} at {i} {j}"]
return properties
for i1 in range(self.size):
for j1 in range(self.size):
if col[i1, j1] >= 0:
- n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+ n1 = (
+ f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+ )
properties += [f"there is a {n1}"]
if i1 < self.size // 2:
properties += [f"a {n1} is in the top half"]
for i2 in range(self.size):
for j2 in range(self.size):
if col[i2, j2] >= 0:
- n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+ n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
if i1 > i2:
properties += [f"a {n1} is below a {n2}"]
if i1 < i2: