3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
# Shape labels; print_scene prints these single letters verbatim.
12 name_shapes = ["A", "B", "C", "D", "E", "F"]
# Color names. print_scene abbreviates each one to its first letter, so keep
# the first letters distinct or the printed grid becomes ambiguous.
14 name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
16 ######################################################################
# NOTE(review): fragment of the (presumably GridFactory) constructor -- its
# `def` line and the other parameters are not present in this text.
# self.height and self.width are read by every other method, but their
# assignments are not visible here either; confirm against the full file.
25 max_nb_transformations=4,
# Upper bound on items per scene (generate_scene draws 2..max_nb_items).
30 self.max_nb_items = max_nb_items
# Upper bound on the number of steps random_transformations applies (0..max).
31 self.max_nb_transformations = max_nb_transformations
# Number of questions kept per sample. generate_scene_and_questions also
# compares the counts of true / false properties against it.
32 self.nb_questions = nb_questions
# Draw a random scene and return it as (col, shp): two (height, width)
# tensors of color / shape indices, with -1 marking empty cells.
34 def generate_scene(self):
# nb_items is uniform in [2, max_nb_items].
35 nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
36 col = torch.full((self.height * self.width,), -1)
37 shp = torch.full((self.height * self.width,), -1)
# Sample distinct (color, shape) combinations without replacement, so no two
# items share both color and shape. Each index decodes as
# color = a % len(name_colors) and shape = a // len(name_colors).
# NOTE(review): assumes max_nb_items <= height * width and
# max_nb_items <= len(name_colors) * len(name_shapes); otherwise the slice
# assignments below raise a shape mismatch.
38 a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
39 col[:nb_items] = a % len(name_colors)
40 shp[:nb_items] = a // len(name_colors)
# Random permutation of the cell indices.
# NOTE(review): the lines that apply it to col / shp (presumably scattering
# the items over the grid) are not present in this text -- confirm.
41 i = torch.randperm(self.height * self.width)
# NOTE(review): the closing parenthesis of this return is not present in this text.
44 return col.reshape(self.height, self.width), shp.reshape(
45 self.height, self.width
# Apply a random sequence of grid flips/rotations to `scene`. Return the
# transformed (col, shp) pair and a list with one "<chg> ..." description
# per applied step.
48 def random_transformations(self, scene):
# NOTE(review): the unpacking of `scene` into col / shp and the
# initialisation of `descriptions` are not present in this text.
# The number of steps is uniform in [0, max_nb_transformations]. Each step is
# one of five transformation codes 0..4.
51 nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
52 transformations = torch.randint(5, (nb_transformations,))
54 for t in transformations:
# NOTE(review): the if/elif lines selecting each case from `t` are not
# present in this text.
# flip(0) reverses the row order (top <-> bottom).
56 col, shp = col.flip(0), shp.flip(0)
57 descriptions += ["<chg> vertical flip"]
# flip(1) reverses the column order (left <-> right).
59 col, shp = col.flip(1), shp.flip(1)
60 descriptions += ["<chg> horizontal flip"]
# flip(0) then transpose gives out[i][j] = in[H-1-j][i] (H = number of rows):
# a clockwise quarter turn.
62 col, shp = col.flip(0).t(), shp.flip(0).t()
63 descriptions += ["<chg> rotate 90 degrees"]
# Flipping both axes gives a half turn.
65 col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
66 descriptions += ["<chg> rotate 180 degrees"]
# flip(1) then transpose gives out[i][j] = in[j][W-1-i] (W = number of columns):
# a counter-clockwise quarter turn, i.e. 270 degrees clockwise.
68 col, shp = col.flip(1).t(), shp.flip(1).t()
69 descriptions += ["<chg> rotate 270 degrees"]
# .t() returns a non-contiguous strided view. Making the result contiguous
# lets callers flatten it with .view(-1), presumably as
# generate_scene_and_questions does.
# NOTE(review): .t() also swaps the two dimensions, so the rotations keep the
# (height, width) shape only when the grid is square -- confirm height == width.
71 return (col.contiguous(), shp.contiguous()), descriptions
# Print the grid to stdout. A printed cell shows the first letter of its color
# followed by its shape letter.
# NOTE(review): the unpacking of `scene` into col / shp is not present in this text.
73 def print_scene(self, scene):
# Earlier verbose per-cell listing, kept commented out.
76 # for i in range(self.height):
77 # for j in range(self.width):
79 # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
81 for i in range(self.height):
82 for j in range(self.width):
# NOTE(review): for an empty cell (-1), name_colors[-1] / name_shapes[-1] would
# print the last entries. A guard for empty cells is presumably on a line not
# present in this text -- confirm.
84 print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
# Between columns (not after the last one); the separator printing itself is
# not present in this text.
89 if j < self.width - 1:
# Between rows (not after the last one), presumably a separator line spanning
# the columns; its printing is not present in this text.
93 if i < self.height - 1:
94 for j in range(self.width - 1):
# List "a <color> <shape> at <i> <j>" strings for the scene's cells, using the
# full color and shape names and row-major order.
# NOTE(review): the unpacking of `scene`, the initialisation and return of
# `properties`, and any skipping of empty (-1) cells are not present in this text.
98 def grid_positions(self, scene):
103 for i in range(self.height):
104 for j in range(self.width):
106 n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
107 properties += [f"a {n} at {i} {j}"]
# Enumerate textual properties of `scene`:
# - that each item exists;
# - which halves of the grid it lies in;
# - pairwise spatial relations (below / above / right of / left of) between items.
# NOTE(review): the unpacking of `scene`, the initialisation and return of
# `properties`, and any skipping of empty (-1) cells are not present in this text.
111 def all_properties(self, scene):
116 for i1 in range(self.height):
117 for j1 in range(self.width):
119 n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
120 properties += [f"there is a {n1}"]
# `<` and `>=` are tested against the same midpoint, so every cell is in
# exactly one of top/bottom and exactly one of left/right. With an odd height
# (width), the middle row (column) counts as bottom (right).
121 if i1 < self.height // 2:
122 properties += [f"a {n1} is in the top half"]
123 if i1 >= self.height // 2:
124 properties += [f"a {n1} is in the bottom half"]
125 if j1 < self.width // 2:
126 properties += [f"a {n1} is in the left half"]
127 if j1 >= self.width // 2:
128 properties += [f"a {n1} is in the right half"]
# Pairwise pass over all cells: work grows quadratically with height * width.
129 for i2 in range(self.height):
130 for j2 in range(self.width):
132 n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
# NOTE(review): the conditions guarding each relation below are not present in
# this text, nor is any skip of empty cells or of (i1, j1) == (i2, j2).
134 properties += [f"a {n1} is below a {n2}"]
136 properties += [f"a {n1} is above a {n2}"]
138 properties += [f"a {n1} is right of a {n2}"]
140 properties += [f"a {n1} is left of a {n2}"]
# Build one sample from a random scene:
# - its grid listing;
# - a random sequence of transformations;
# - a shuffled subset of "<prop> ... <true>" / "<prop> ... <false>" questions.
# Returns a pair (scene, ...); the __main__ block unpacks it as
# (scene, questions), and generate_samples keeps only the second element.
144 def generate_scene_and_questions(self):
# Draw scenes until one has at least nb_questions true properties (the loop
# and break lines are not present in this text).
147 scene = self.generate_scene()
148 true = self.all_properties(scene)
149 if len(true) >= self.nb_questions:
# Grid listing of the scene before any transformation.
152 start = self.grid_positions(scene)
154 scene, transformations = self.random_transformations(scene)
# Distractor scene: a random permutation of the cells of (presumably) the
# transformed scene -- same items, shuffled positions. The lines unpacking
# `scene` and opening `other_scene = (` are not present in this text.
158 col, shp = col.view(-1), shp.view(-1)
159 p = torch.randperm(col.size(0))
160 col, shp = col[p], shp[p]
162 col.view(self.height, self.width),
163 shp.view(self.height, self.width),
# Alternative kept commented out: take the distractor from an independent scene.
165 # other_scene = self.generate_scene()
# False candidates: properties of the distractor that are not in `true`.
# NOTE(review): in the visible lines `true` is computed on the scene *before*
# random_transformations. Unless it is recomputed on lines not shown, these
# labels refer to the pre-transformation scene -- confirm.
# NOTE(review): the order of list(set(...)) depends on Python's string-hash
# randomisation (PYTHONHASHSEED). The subset sampled below is therefore not
# reproducible from torch's seed alone; sorting first would fix that.
166 false = list(set(self.all_properties(other_scene)) - set(true))
167 if len(false) >= self.nb_questions:
# Keep a random subset of at most nb_questions true and false properties each.
175 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
176 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
177 true = ["<prop> " + q + " <true>" for q in true]
178 false = ["<prop> " + q + " <false>" for q in false]
# `union` is built on lines not present in this text (presumably true + false).
# Keep a random subset of nb_questions of it.
181 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
# Item listing of the final scene; `scene` was rebound to the transformed one above.
184 ["<obj> " + x for x in self.grid_positions(scene)]
# Generate `nb` samples, keeping only the second element of each
# generate_scene_and_questions() result (the scene tensors are dropped).
# `progress_bar`, when given, presumably wraps the iteration. Its use, the
# loop header, and the initialisation and return of `result` are not present
# in this text.
191 def generate_samples(self, nb, progress_bar=None):
195 if progress_bar is not None:
199 result.append(self.generate_scene_and_questions()[1])
204 ######################################################################
# Script entry point: benchmark sample generation, then show one example scene.
206 if __name__ == "__main__":
209 grid_factory = GridFactory()
# Time the generation of 10000 samples and report the throughput.
# NOTE(review): `time` is not among the imports visible in this text. An
# `import time` is presumably on a line not shown -- confirm.
211 start_time = time.perf_counter()
212 samples = grid_factory.generate_samples(10000)
213 end_time = time.perf_counter()
214 print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
# Generate one sample and print its scene. Any printing of `questions` is on
# lines not present in this text.
216 scene, questions = grid_factory.generate_scene_and_questions()
217 grid_factory.print_scene(scene)
220 ######################################################################