3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
12 name_shapes = ["A", "B", "C", "D", "E", "F"]
14 name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
16 ######################################################################
24 max_nb_transformations=3,
28 self.max_nb_items = max_nb_items
29 self.max_nb_transformations = max_nb_transformations
30 self.nb_questions = nb_questions
32 def generate_scene(self):
33 nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
34 col = torch.full((self.size * self.size,), -1)
35 shp = torch.full((self.size * self.size,), -1)
36 a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
37 col[:nb_items] = a % len(name_colors)
38 shp[:nb_items] = a // len(name_colors)
39 i = torch.randperm(self.size * self.size)
42 return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
44 def random_transformations(self, scene):
48 nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
49 transformations = torch.randint(5, (nb_transformations,))
51 for t in transformations:
53 col, shp = col.flip(0), shp.flip(0)
54 descriptions += ["<chg> vertical flip"]
56 col, shp = col.flip(1), shp.flip(1)
57 descriptions += ["<chg> horizontal flip"]
59 col, shp = col.flip(0).t(), shp.flip(0).t()
60 descriptions += ["<chg> rotate 90 degrees"]
62 col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
63 descriptions += ["<chg> rotate 180 degrees"]
65 col, shp = col.flip(1).t(), shp.flip(1).t()
66 descriptions += ["<chg> rotate 270 degrees"]
68 col, shp = col.contiguous(), shp.contiguous()
70 return (col, shp), descriptions
72 def print_scene(self, scene):
75 # for i in range(self.size):
76 # for j in range(self.size):
78 # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
80 for i in range(self.size):
81 for j in range(self.size):
83 print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
93 for j in range(self.size - 1):
97 def grid_positions(self, scene):
102 for i in range(self.size):
103 for j in range(self.size):
105 n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
106 properties += [f"a {n} at {i} {j}"]
110 def all_properties(self, scene):
115 for i1 in range(self.size):
116 for j1 in range(self.size):
118 n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
119 properties += [f"there is a {n1}"]
120 if i1 < self.size // 2:
121 properties += [f"a {n1} is in the top half"]
122 if i1 >= self.size // 2:
123 properties += [f"a {n1} is in the bottom half"]
124 if j1 < self.size // 2:
125 properties += [f"a {n1} is in the left half"]
126 if j1 >= self.size // 2:
127 properties += [f"a {n1} is in the right half"]
128 for i2 in range(self.size):
129 for j2 in range(self.size):
131 n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
133 properties += [f"a {n1} is below a {n2}"]
135 properties += [f"a {n1} is above a {n2}"]
137 properties += [f"a {n1} is right of a {n2}"]
139 properties += [f"a {n1} is left of a {n2}"]
143 def generate_scene_and_questions(self):
146 start_scene = self.generate_scene()
147 true = self.all_properties(start_scene)
148 if len(true) >= self.nb_questions:
151 start = self.grid_positions(start_scene)
153 scene, transformations = self.random_transformations(start_scene)
159 col, shp = col.view(-1), shp.view(-1)
160 p = torch.randperm(col.size(0))
161 col, shp = col[p], shp[p]
163 col.view(self.size, self.size),
164 shp.view(self.size, self.size),
166 # other_scene = self.generate_scene()
167 false = list(set(self.all_properties(other_scene)) - set(true))
168 if len(false) >= self.nb_questions:
174 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
175 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
176 true = ["<prop> " + q + " <true>" for q in true]
177 false = ["<prop> " + q + " <false>" for q in false]
180 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
183 ["<obj> " + x for x in self.grid_positions(scene)]
188 return start_scene, scene, result
190 def generate_samples(self, nb, progress_bar=None):
194 if progress_bar is not None:
198 result.append(self.generate_scene_and_questions()[2])
203 ######################################################################
205 if __name__ == "__main__":
208 grid_factory = GridFactory()
210 # start_time = time.perf_counter()
211 # samples = grid_factory.generate_samples(10000)
212 # end_time = time.perf_counter()
213 # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
215 start_scene, scene, questions = grid_factory.generate_scene_and_questions()
216 print("-- Original scene -----------------------------")
217 grid_factory.print_scene(start_scene)
218 print("-- Transformed scene --------------------------")
219 grid_factory.print_scene(scene)
220 print("-- Sequence -----------------------------------")
223 ######################################################################