class GridFactory:
def __init__(
self,
- height=4,
- width=4,
+ size=4,
max_nb_items=4,
- max_nb_transformations=4,
+ max_nb_transformations=3,
nb_questions=4,
):
- self.height = height
- self.width = width
+ self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
def generate_scene(self):
nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
- col = torch.full((self.height * self.width,), -1)
- shp = torch.full((self.height * self.width,), -1)
+ col = torch.full((self.size * self.size,), -1)
+ shp = torch.full((self.size * self.size,), -1)
a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
col[:nb_items] = a % len(name_colors)
shp[:nb_items] = a // len(name_colors)
- i = torch.randperm(self.height * self.width)
+ i = torch.randperm(self.size * self.size)
col = col[i]
shp = shp[i]
- return col.reshape(self.height, self.width), shp.reshape(
- self.height, self.width
- )
+ return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
def random_transformations(self, scene):
col, shp = scene
+
descriptions = []
nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
transformations = torch.randint(5, (nb_transformations,))
col, shp = col.flip(1).t(), shp.flip(1).t()
descriptions += ["<chg> rotate 270 degrees"]
- return (col.contiguous(), shp.contiguous()), descriptions
+ col, shp = col.contiguous(), shp.contiguous()
+
+ return (col, shp), descriptions
def print_scene(self, scene):
col, shp = scene
- # for i in range(self.height):
- # for j in range(self.width):
+ # for i in range(self.size):
+ # for j in range(self.size):
# if col[i,j] >= 0:
# print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
- for i in range(self.height):
- for j in range(self.width):
+ for i in range(self.size):
+ for j in range(self.size):
if col[i, j] >= 0:
print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
elif j == 0:
print(" +", end="")
else:
print("-+", end="")
- if j < self.width - 1:
+ if j < self.size - 1:
print("--", end="")
else:
print("")
- if i < self.height - 1:
- for j in range(self.width - 1):
+ if i < self.size - 1:
+ for j in range(self.size - 1):
print(" | ", end="")
print(" |")
properties = []
- for i in range(self.height):
- for j in range(self.width):
+ for i in range(self.size):
+ for j in range(self.size):
if col[i, j] >= 0:
n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
properties += [f"a {n} at {i} {j}"]
properties = []
- for i1 in range(self.height):
- for j1 in range(self.width):
+ for i1 in range(self.size):
+ for j1 in range(self.size):
if col[i1, j1] >= 0:
n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
properties += [f"there is a {n1}"]
- if i1 < self.height // 2:
+ if i1 < self.size // 2:
properties += [f"a {n1} is in the top half"]
- if i1 >= self.height // 2:
+ if i1 >= self.size // 2:
properties += [f"a {n1} is in the bottom half"]
- if j1 < self.width // 2:
+ if j1 < self.size // 2:
properties += [f"a {n1} is in the left half"]
- if j1 >= self.width // 2:
+ if j1 >= self.size // 2:
properties += [f"a {n1} is in the right half"]
- for i2 in range(self.height):
- for j2 in range(self.width):
+ for i2 in range(self.size):
+ for j2 in range(self.size):
if col[i2, j2] >= 0:
n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
if i1 > i2:
scene, transformations = self.random_transformations(scene)
+ # transformations=[]
+
for a in range(10):
col, shp = scene
col, shp = col.view(-1), shp.view(-1)
p = torch.randperm(col.size(0))
col, shp = col[p], shp[p]
other_scene = (
- col.view(self.height, self.width),
- shp.view(self.height, self.width),
+ col.view(self.size, self.size),
+ shp.view(self.size, self.size),
)
# other_scene = self.generate_scene()
false = list(set(self.all_properties(other_scene)) - set(true))
if len(false) >= self.nb_questions:
break
- # print(f"{a=}")
-
if a < 10:
break