3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
12 ######################################################################
20 max_nb_transformations=3,
28 self.max_nb_items = max_nb_items
29 self.max_nb_transformations = max_nb_transformations
30 self.nb_questions = nb_questions
31 self.nb_play_steps = nb_play_steps
32 self.name_shapes = ["A", "B", "C", "D", "E", "F"]
33 self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
34 self.vname_shapes = ["vA", "vB", "vC", "vD", "vE", "vF"]
35 self.vname_colors = ["vred", "vyellow", "vblue", "vgreen", "vwhite", "vpurple"]
37 def generate_scene(self):
38 nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
39 col = torch.full((self.size * self.size,), -1)
40 shp = torch.full((self.size * self.size,), -1)
41 a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
42 col[:nb_items] = a % len(self.name_colors)
43 shp[:nb_items] = a // len(self.name_colors)
44 i = torch.randperm(self.size * self.size)
47 return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
49 def random_object_move(self, scene):
52 a = (col.flatten() >= 0).nonzero()
53 a = a[torch.randint(a.size(0), (1,)).item()]
54 i, j = a // self.size, a % self.size
56 dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
63 and col[x[0], x[1]] < 0,
68 ni, nj = dst[torch.randint(len(dst), (1,)).item()]
69 col[ni, nj] = col[i, j]
70 shp[ni, nj] = shp[i, j]
77 def transformation(self, t, scene):
80 col, shp = col.flip(0), shp.flip(0)
81 description = "<chg> vertical flip"
83 col, shp = col.flip(1), shp.flip(1)
84 description = "<chg> horizontal flip"
86 col, shp = col.flip(0).t(), shp.flip(0).t()
87 description = "<chg> rotate 90 degrees"
89 col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
90 description = "<chg> rotate 180 degrees"
92 col, shp = col.flip(1).t(), shp.flip(1).t()
93 description = "<chg> rotate 270 degrees"
95 return (col.contiguous(), shp.contiguous()), description
97 def random_transformations(self, scene):
99 nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
100 transformations = torch.randint(5, (nb_transformations,))
102 for t in transformations:
103 scene, description = self.transformation(t, scene)
104 descriptions += [description]
106 return scene, descriptions
108 def visual_scene2str(self, scene):
111 for i in range(self.size):
113 for j in range(self.size):
115 s += [self.vname_colors[col[i, j]], self.vname_shapes[shp[i, j]]]
118 r += s # .append(" ".join(s))
121 def print_scene(self, scene):
124 # for i in range(self.size):
125 # for j in range(self.size):
127 # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
129 for i in range(self.size):
130 for j in range(self.size):
133 f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
140 if j < self.size - 1:
144 if i < self.size - 1:
145 for j in range(self.size - 1):
149 def grid_positions(self, scene):
154 for i in range(self.size):
155 for j in range(self.size):
157 n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
158 properties += [f"a {n} at {i} {j}"]
162 def all_properties(self, scene):
167 for i1 in range(self.size):
168 for j1 in range(self.size):
171 f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
173 properties += [f"there is a {n1}"]
174 if i1 < self.size // 2:
175 properties += [f"a {n1} is in the top half"]
176 if i1 >= self.size // 2:
177 properties += [f"a {n1} is in the bottom half"]
178 if j1 < self.size // 2:
179 properties += [f"a {n1} is in the left half"]
180 if j1 >= self.size // 2:
181 properties += [f"a {n1} is in the right half"]
182 for i2 in range(self.size):
183 for j2 in range(self.size):
185 n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
187 properties += [f"a {n1} is below a {n2}"]
189 properties += [f"a {n1} is above a {n2}"]
191 properties += [f"a {n1} is right of a {n2}"]
193 properties += [f"a {n1} is left of a {n2}"]
194 if abs(i1 - i2) + abs(j1 - j2) == 1:
195 properties += [f"a {n1} is next to a {n2}"]
199 def generate_scene_and_play(self):
200 scene = self.generate_scene()
201 steps = [self.visual_scene2str(scene)]
202 for t in range(self.nb_play_steps - 1):
203 if torch.randint(4, (1,)).item() == 0:
204 scene, _ = self.transformation(torch.randint(5, (1,)), scene)
206 scene = self.random_object_move(scene)
207 steps.append(self.visual_scene2str(scene))
208 return " | ".join(steps)
210 def generate_scene_and_questions(self):
212 # We generate scenes until we get one with enough
216 start_scene = self.generate_scene()
217 scene, transformations = self.random_transformations(start_scene)
218 true = self.all_properties(scene)
219 if len(true) >= self.nb_questions:
222 # We generate a bunch of false properties by shuffling the
223 # scene and sometimes adding properties from totally
224 # different scenes. We try ten times to get enough false
225 # properties and go back to generating the scene if we do
230 col, shp = col.view(-1), shp.view(-1)
231 p = torch.randperm(col.size(0))
232 col, shp = col[p], shp[p]
234 col.view(self.size, self.size),
235 shp.view(self.size, self.size),
238 false = self.all_properties(other_scene)
240 # We sometime add properties from a totally different
241 # scene to have negative "there is a xxx xxx"
244 if torch.rand(1).item() < 0.2:
245 other_scene = self.generate_scene()
246 false += self.all_properties(other_scene)
248 false = list(set(false) - set(true))
249 if len(false) >= self.nb_questions:
255 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
256 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
257 true = ["<prop> " + q + " <ans> true" for q in true]
258 false = ["<prop> " + q + " <ans> false" for q in false]
261 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
264 ["<obj> " + x for x in self.grid_positions(start_scene)]
269 return start_scene, scene, result
271 def generate_samples(self, nb, fraction_play=0.0, progress_bar=None):
274 play = torch.rand(nb) < fraction_play
275 if progress_bar is not None:
276 play = progress_bar(play)
280 result.append(self.generate_scene_and_play())
282 result.append(self.generate_scene_and_questions()[2])
287 ######################################################################
289 if __name__ == "__main__":
292 grid_factory = GridFactory()
294 # start_time = time.perf_counter()
295 # samples = grid_factory.generate_samples(10000)
296 # end_time = time.perf_counter()
297 # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
299 start_scene, scene, questions = grid_factory.generate_scene_and_questions()
301 print("-- Original scene -----------------------------")
303 grid_factory.print_scene(start_scene)
305 print("-- Transformed scene --------------------------")
307 grid_factory.print_scene(scene)
309 print("-- Sequence -----------------------------------")
313 # print(grid_factory.visual_scene2str(scene))
315 # grid_factory.print_scene(scene)
317 # scene = grid_factory.random_object_move(scene)
319 # grid_factory.print_scene(scene)
321 print(grid_factory.generate_scene_and_play())
323 ######################################################################