Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jan 2024 13:31:03 +0000 (14:31 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 16 Jan 2024 13:31:03 +0000 (14:31 +0100)
grid.py

diff --git a/grid.py b/grid.py
index 268f4ee..2135710 100755 (executable)
--- a/grid.py
+++ b/grid.py
@@ -9,10 +9,6 @@ import math
 import torch, torchvision
 import torch.nn.functional as F
 
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
 ######################################################################
 
 
@@ -23,20 +19,24 @@ class GridFactory:
         max_nb_items=4,
         max_nb_transformations=3,
         nb_questions=4,
+        nb_shapes=6,
+        nb_colors=6,
     ):
         assert size % 2 == 0
         self.size = size
         self.max_nb_items = max_nb_items
         self.max_nb_transformations = max_nb_transformations
         self.nb_questions = nb_questions
+        self.name_shapes = ["A", "B", "C", "D", "E", "F"]
+        self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
 
     def generate_scene(self):
         nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
         col = torch.full((self.size * self.size,), -1)
         shp = torch.full((self.size * self.size,), -1)
-        a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
-        col[:nb_items] = a % len(name_colors)
-        shp[:nb_items] = a // len(name_colors)
+        a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+        col[:nb_items] = a % len(self.name_colors)
+        shp[:nb_items] = a // len(self.name_colors)
         i = torch.randperm(self.size * self.size)
         col = col[i]
         shp = shp[i]
@@ -76,12 +76,15 @@ class GridFactory:
         # for i in range(self.size):
         # for j in range(self.size):
         # if col[i,j] >= 0:
-        # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
+        # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
 
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+                    print(
+                        f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+                        end="",
+                    )
                 elif j == 0:
                     print(" +", end="")
                 else:
@@ -103,7 +106,7 @@ class GridFactory:
         for i in range(self.size):
             for j in range(self.size):
                 if col[i, j] >= 0:
-                    n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
+                    n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
                     properties += [f"a {n} at {i} {j}"]
 
         return properties
@@ -116,7 +119,9 @@ class GridFactory:
         for i1 in range(self.size):
             for j1 in range(self.size):
                 if col[i1, j1] >= 0:
-                    n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+                    n1 = (
+                        f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+                    )
                     properties += [f"there is a {n1}"]
                     if i1 < self.size // 2:
                         properties += [f"a {n1} is in the top half"]
@@ -129,7 +134,7 @@ class GridFactory:
                     for i2 in range(self.size):
                         for j2 in range(self.size):
                             if col[i2, j2] >= 0:
-                                n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
+                                n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
                                 if i1 > i2:
                                     properties += [f"a {n1} is below a {n2}"]
                                 if i1 < i2: