3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
12 ######################################################################
20 max_nb_transformations=3,
27 self.max_nb_items = max_nb_items
28 self.max_nb_transformations = max_nb_transformations
29 self.nb_questions = nb_questions
30 self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)]
73 "medium_spring_green",
133 "light_golden_rod_yellow",
169 def generate_scene(self):
# Sample a random scene. A scene is a pair (col, shp) of integer tensors,
# returned reshaped to (self.size, self.size): col holds an index into
# self.name_colors, shp an index into self.name_shapes, and cells that
# receive no item keep the fill value -1.
# NOTE(review): some statements of this method are not present in this
# excerpt; the notes below only describe the visible lines.
#
# torch.randint(high, ...) samples from [0, high), so nb_items is uniform
# in [2, self.max_nb_items].
170 nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
# Flat buffers with one slot per grid cell, all initialized to -1.
171 col = torch.full((self.size * self.size,), -1)
172 shp = torch.full((self.size * self.size,), -1)
# Draw nb_items *distinct* (color, shape) combination indices, so no two
# items share both color and shape. Each index decodes as
# color = a % nb_colors, shape = a // nb_colors.
173 a = torch.randperm(len(self.name colors) * len(self.name_shapes))[:nb_items] if False else None
181 def random_transformations(self, scene):
185 nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
186 transformations = torch.randint(5, (nb_transformations,))
188 for t in transformations:
190 col, shp = col.flip(0), shp.flip(0)
191 descriptions += ["<chg> vertical flip"]
193 col, shp = col.flip(1), shp.flip(1)
194 descriptions += ["<chg> horizontal flip"]
196 col, shp = col.flip(0).t(), shp.flip(0).t()
197 descriptions += ["<chg> rotate 90 degrees"]
199 col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
200 descriptions += ["<chg> rotate 180 degrees"]
202 col, shp = col.flip(1).t(), shp.flip(1).t()
203 descriptions += ["<chg> rotate 270 degrees"]
205 col, shp = col.contiguous(), shp.contiguous()
207 return (col, shp), descriptions
209 def print_scene(self, scene):
212 # for i in range(self.size):
213 # for j in range(self.size):
215 # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
217 for i in range(self.size):
218 for j in range(self.size):
221 f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
228 if j < self.size - 1:
232 if i < self.size - 1:
233 for j in range(self.size - 1):
237 def grid_positions(self, scene):
242 for i in range(self.size):
243 for j in range(self.size):
245 n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
246 properties += [f"a {n} at {i} {j}"]
250 def all_properties(self, scene):
255 for i1 in range(self.size):
256 for j1 in range(self.size):
259 f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
261 properties += [f"there is a {n1}"]
262 if i1 < self.size // 2:
263 properties += [f"a {n1} is in the top half"]
264 if i1 >= self.size // 2:
265 properties += [f"a {n1} is in the bottom half"]
266 if j1 < self.size // 2:
267 properties += [f"a {n1} is in the left half"]
268 if j1 >= self.size // 2:
269 properties += [f"a {n1} is in the right half"]
270 for i2 in range(self.size):
271 for j2 in range(self.size):
273 n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
275 properties += [f"a {n1} is below a {n2}"]
277 properties += [f"a {n1} is above a {n2}"]
279 properties += [f"a {n1} is right of a {n2}"]
281 properties += [f"a {n1} is left of a {n2}"]
282 if abs(i1 - i2) + abs(j1 - j2) == 1:
283 properties += [f"a {n1} is next to a {n2}"]
287 def generate_scene_and_questions(self):
290 start_scene = self.generate_scene()
291 scene, transformations = self.random_transformations(start_scene)
292 true = self.all_properties(scene)
293 if len(true) >= self.nb_questions:
298 col, shp = col.view(-1), shp.view(-1)
299 p = torch.randperm(col.size(0))
300 col, shp = col[p], shp[p]
302 col.view(self.size, self.size),
303 shp.view(self.size, self.size),
306 false = self.all_properties(other_scene)
308 # We sometime add properties from a totally different
309 # scene to have negative "there is a xxx xxx"
311 if torch.rand(1).item() < 0.2:
312 other_scene = self.generate_scene()
313 false += self.all_properties(other_scene)
315 false = list(set(false) - set(true))
316 if len(false) >= self.nb_questions:
322 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
323 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
324 true = ["<prop> " + q + " <ans> true" for q in true]
325 false = ["<prop> " + q + " <ans> false" for q in false]
328 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
331 ["<obj> " + x for x in self.grid_positions(start_scene)]
336 return start_scene, scene, result
338 def generate_samples(self, nb, progress_bar=None):
342 if progress_bar is not None:
346 result.append(self.generate_scene_and_questions()[2])
351 ######################################################################
353 if __name__ == "__main__":
356 grid_factory = GridFactory()
358 # start_time = time.perf_counter()
359 # samples = grid_factory.generate_samples(10000)
360 # end_time = time.perf_counter()
361 # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
363 start_scene, scene, questions = grid_factory.generate_scene_and_questions()
365 print("-- Original scene -----------------------------")
367 grid_factory.print_scene(start_scene)
369 print("-- Transformed scene --------------------------")
371 grid_factory.print_scene(scene)
373 print("-- Sequence -----------------------------------")
377 ######################################################################