3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
12 ######################################################################
20 max_nb_transformations=3,
27 self.max_nb_items = max_nb_items
28 self.max_nb_transformations = max_nb_transformations
29 self.nb_questions = nb_questions
30 self.name_shapes = ["A", "B", "C", "D", "E", "F"]
31 self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
def generate_scene(self):
    """Draw a random scene on a ``size`` x ``size`` grid.

    Between 2 and ``max_nb_items`` items are created, each with a distinct
    (colour, shape) combination, and scattered over distinct random cells.

    Returns:
        A pair ``(col, shp)`` of ``(size, size)`` integer tensors holding,
        for every cell, the index into ``name_colors`` / ``name_shapes`` of
        the item it contains, or ``-1`` when the cell is empty.
    """
    nb_cells = self.size * self.size
    nb_colors = len(self.name_colors)
    nb_kinds = nb_colors * len(self.name_shapes)

    nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
    # There can be neither more items than cells nor more than the number
    # of distinct (colour, shape) pairs; without this clamp the slice
    # assignments below fail with a size mismatch on tiny grids.
    nb_items = min(nb_items, nb_cells, nb_kinds)

    col = torch.full((nb_cells,), -1)
    shp = torch.full((nb_cells,), -1)

    # One integer in [0, nb_kinds) encodes a (colour, shape) pair; taking
    # the head of a permutation guarantees the pairs are all different.
    kinds = torch.randperm(nb_kinds)[:nb_items]
    col[:nb_items] = kinds % nb_colors
    shp[:nb_items] = kinds // nb_colors

    # Shuffle the cells so the items do not all sit at the start of the
    # grid. The same permutation is applied to both tensors to keep each
    # colour paired with its shape.
    # NOTE(review): these two indexing statements were absent from the
    # listing (the permutation was computed but never used) -- confirm
    # against the full source.
    i = torch.randperm(nb_cells)
    col, shp = col[i], shp[i]

    return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
def random_transformations(self, scene):
    """Apply a random sequence of flips / rotations to a scene.

    The number of transformations is drawn uniformly in
    ``[0, max_nb_transformations]`` and each one is drawn uniformly among
    the five entries of the table below.

    Args:
        scene: pair ``(col, shp)`` of 2-D tensors describing the grid.

    Returns:
        ``((col, shp), descriptions)`` where the tensors are the
        transformed, contiguous grids and ``descriptions`` lists one
        ``"<chg> ..."`` string per applied transformation, in order.
    """
    # NOTE(review): the branch conditions were missing from the listing;
    # indices 0-4 are assumed to follow the order in which the branches
    # appeared -- confirm against the full source.
    moves = (
        ("<chg> vertical flip", lambda x: x.flip(0)),
        ("<chg> horizontal flip", lambda x: x.flip(1)),
        # Reversing the rows then transposing is a clockwise quarter turn.
        ("<chg> rotate 90 degrees", lambda x: x.flip(0).t()),
        ("<chg> rotate 180 degrees", lambda x: x.flip(0).flip(1)),
        # Reversing the columns then transposing is three clockwise
        # quarter turns.
        ("<chg> rotate 270 degrees", lambda x: x.flip(1).t()),
    )

    col, shp = scene
    descriptions = []

    nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
    transformations = torch.randint(len(moves), (nb_transformations,))

    for t in transformations.tolist():
        label, move = moves[t]
        # Colours and shapes must always move together.
        col, shp = move(col), move(shp)
        descriptions.append(label)

    # Transposition yields non-contiguous views, and callers flatten the
    # result with view(-1), which needs contiguous storage.
    col, shp = col.contiguous(), shp.contiguous()

    return (col, shp), descriptions
# Presumably prints a text rendering of `scene` (inferred from the name and
# from the commented-out print below -- the actual output call is not
# visible here).
# NOTE(review): the embedded listing numbers jump (73 -> 76, 85 -> 97), so
# several statements of this method are missing from this view; only the
# visible ones are described.
73 def print_scene(self, scene):
76 # for i in range(self.size):
77 # for j in range(self.size):
79 # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
# Visit every cell of the size x size grid, row index i then column index j.
81 for i in range(self.size):
82 for j in range(self.size):
# Cell label: first letter of the colour name followed by the shape name.
85 f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
# Loops over size - 1 positions -- presumably the separators between
# columns; confirm against the full source.
97 for j in range(self.size - 1):
def grid_positions(self, scene):
    """Describe where every item of a scene is located.

    Args:
        scene: pair ``(col, shp)`` of ``(size, size)`` integer tensors;
            a negative colour index marks an empty cell.

    Returns:
        A list with one ``"a <color> <shape> at <row> <column>"`` string
        per occupied cell, in row-major order.
    """
    col, shp = scene
    # Plain nested lists are cheaper to index than 0-dim tensors.
    colors, shapes = col.tolist(), shp.tolist()

    properties = []
    for i in range(self.size):
        for j in range(self.size):
            # NOTE(review): this guard was missing from the listing. It is
            # reconstructed because empty cells hold -1, which would
            # otherwise silently index the last colour / shape name.
            if colors[i][j] >= 0:
                n = f"{self.name_colors[colors[i][j]]} {self.name_shapes[shapes[i][j]]}"
                properties.append(f"a {n} at {i} {j}")

    return properties
def all_properties(self, scene):
    """List every true statement about the items of a scene.

    For each item this produces its existence, the vertical and horizontal
    half of the grid it lies in, and its position relative to every item
    of the scene (below / above / right of / left of / next to).

    Args:
        scene: pair ``(col, shp)`` of ``(size, size)`` integer tensors;
            a negative colour index marks an empty cell.

    Returns:
        The list of property strings, items being visited in row-major
        order.
    """
    col, shp = scene
    colors, shapes = col.tolist(), shp.tolist()
    half = self.size // 2

    # Collect the occupied cells and their names once, instead of
    # re-formatting the names inside the quadruple loop over cell pairs.
    # NOTE(review): the empty-cell guards were missing from the listing;
    # they are reconstructed because empty cells hold -1.
    items = [
        (i, j, f"{self.name_colors[colors[i][j]]} {self.name_shapes[shapes[i][j]]}")
        for i in range(self.size)
        for j in range(self.size)
        if colors[i][j] >= 0
    ]

    properties = []
    for i1, j1, n1 in items:
        properties.append(f"there is a {n1}")

        # Row indices grow downwards and column indices grow rightwards.
        if i1 < half:
            properties.append(f"a {n1} is in the top half")
        else:
            properties.append(f"a {n1} is in the bottom half")
        if j1 < half:
            properties.append(f"a {n1} is in the left half")
        else:
            properties.append(f"a {n1} is in the right half")

        # NOTE(review): the four comparisons below were missing from the
        # listing; they follow the same row / column convention as the
        # half-plane tests above -- confirm against the full source.
        for i2, j2, n2 in items:
            if i1 > i2:
                properties.append(f"a {n1} is below a {n2}")
            if i1 < i2:
                properties.append(f"a {n1} is above a {n2}")
            if j1 > j2:
                properties.append(f"a {n1} is right of a {n2}")
            if j1 < j2:
                properties.append(f"a {n1} is left of a {n2}")
            # Manhattan distance 1: the two cells share an edge.
            if abs(i1 - i2) + abs(j1 - j2) == 1:
                properties.append(f"a {n1} is next to a {n2}")

    return properties
# Builds one sample and returns the triple (start_scene, scene, result):
# a random start scene, its randomly transformed version, and `result`.
# NOTE(review): the embedded listing numbers jump in several places (e.g.
# 157 -> 162, 180 -> 186, 195 -> 200), so loop headers, breaks and the
# construction of `other_scene`, `union` and `result` are not visible in
# this view; only the visible statements are described.
151 def generate_scene_and_questions(self):
# Draw a scene, transform it, and list everything that holds in the
# transformed scene.
154 start_scene = self.generate_scene()
155 scene, transformations = self.random_transformations(start_scene)
156 true = self.all_properties(scene)
# At least nb_questions true statements are needed.
157 if len(true) >= self.nb_questions:
# Flatten both grids and shuffle their cells with one shared permutation,
# which moves the same items to different positions.
162 col, shp = col.view(-1), shp.view(-1)
163 p = torch.randperm(col.size(0))
164 col, shp = col[p], shp[p]
166 col.view(self.size, self.size),
167 shp.view(self.size, self.size),
# Candidate false statements: the properties of the shuffled scene.
170 false = self.all_properties(other_scene)
172 # We sometimes add properties from a totally different
173 # scene to have negative "there is a xxx xxx"
# This happens with probability 0.2.
175 if torch.rand(1).item() < 0.2:
176 other_scene = self.generate_scene()
177 false += self.all_properties(other_scene)
# Drop candidates that also hold in the real scene; going through sets
# removes duplicates as well.
179 false = list(set(false) - set(true))
180 if len(false) >= self.nb_questions:
# Keep nb_questions randomly chosen statements of each kind and tag them
# with their answer.
186 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
187 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
188 true = ["<prop> " + q + " <ans> true" for q in true]
189 false = ["<prop> " + q + " <ans> false" for q in false]
# Draw nb_questions entries at random from `union` (defined on a line not
# visible here -- presumably the true and false statements together).
192 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
# Item positions of the start scene, each prefixed with "<obj> ".
195 ["<obj> " + x for x in self.grid_positions(start_scene)]
200 return start_scene, scene, result
# Collects samples into `result`; presumably `nb` of them (the loop header
# and the return statement are not visible in this view -- the embedded
# listing numbers jump 202 -> 206 -> 210).
202 def generate_samples(self, nb, progress_bar=None):
# `progress_bar` is optional; what is done with it is on a line not
# visible here.
206 if progress_bar is not None:
# Keep only the third element (index 2) of the triple returned by
# generate_scene_and_questions.
210 result.append(self.generate_scene_and_questions()[2])
215 ######################################################################
# Script entry point: generates one sample and prints the start scene and
# the transformed scene under labelled banners.
# NOTE(review): some lines of this section are missing from this view
# (e.g. whatever follows the "Sequence" banner).
217 if __name__ == "__main__":
220 grid_factory = GridFactory()
# Disabled throughput benchmark (samples generated per second).
222 # start_time = time.perf_counter()
223 # samples = grid_factory.generate_samples(10000)
224 # end_time = time.perf_counter()
225 # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
227 start_scene, scene, questions = grid_factory.generate_scene_and_questions()
229 print("-- Original scene -----------------------------")
231 grid_factory.print_scene(start_scene)
233 print("-- Transformed scene --------------------------")
235 grid_factory.print_scene(scene)
237 print("-- Sequence -----------------------------------")
241 ######################################################################