3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import torch, torchvision
10 import torch.nn.functional as F
12 name_shapes = ["A", "B", "C", "D", "E", "F"]
14 name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
16 ######################################################################
25 max_nb_transformations=4,
30 self.max_nb_items = max_nb_items
31 self.nb_questions = nb_questions
33 def generate_scene(self):
34 nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
35 col = torch.full((self.height * self.width,), -1)
36 shp = torch.full((self.height * self.width,), -1)
37 a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items]
38 col[:nb_items] = a % len(name_colors)
39 shp[:nb_items] = a // len(name_colors)
40 i = torch.randperm(self.height * self.width)
43 return col.reshape(self.height, self.width), shp.reshape(
44 self.height, self.width
47 def random_transformations(self):
48 nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
50 def print_scene(self, scene):
53 # for i in range(self.height):
54 # for j in range(self.width):
56 # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}")
58 for i in range(self.height):
59 for j in range(self.width):
61 print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
66 if j < self.width - 1:
70 if i < self.height - 1:
71 for j in range(self.width - 1):
75 def grid_positions(self, scene):
80 for i in range(self.height):
81 for j in range(self.width):
83 n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}"
84 properties += [f"a {n} at {i} {j}"]
88 def all_properties(self, scene):
93 for i1 in range(self.height):
94 for j1 in range(self.width):
96 n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
97 properties += [f"there is a {n1}"]
98 if i1 < self.height // 2:
99 properties += [f"a {n1} is in the top half"]
100 if i1 >= self.height // 2:
101 properties += [f"a {n1} is in the bottom half"]
102 if j1 < self.width // 2:
103 properties += [f"a {n1} is in the left half"]
104 if j1 >= self.width // 2:
105 properties += [f"a {n1} is in the right half"]
106 for i2 in range(self.height):
107 for j2 in range(self.width):
109 n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}"
111 properties += [f"a {n1} is below a {n2}"]
113 properties += [f"a {n1} is above a {n2}"]
115 properties += [f"a {n1} is right of a {n2}"]
117 properties += [f"a {n1} is left of a {n2}"]
121 def generate_scene_and_questions(self):
124 scene = self.generate_scene()
125 true = self.all_properties(scene)
126 if len(true) >= self.nb_questions:
129 start = self.grid_positions(scene)
133 col, shp = col.view(-1), shp.view(-1)
134 p = torch.randperm(col.size(0))
135 col, shp = col[p], shp[p]
137 col.view(self.height, self.width),
138 shp.view(self.height, self.width),
140 # other_scene = self.generate_scene()
141 false = list(set(self.all_properties(other_scene)) - set(true))
142 if len(false) >= self.nb_questions:
150 true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
151 false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
152 true = ["<prop> " + q + " <true>" for q in true]
153 false = ["<prop> " + q + " <false>" for q in false]
156 questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
159 ["<obj> " + x for x in self.grid_positions(scene)] + questions
164 def generate_samples(self, nb, progress_bar=None):
168 if progress_bar is not None:
172 result.append(self.generate_scene_and_questions()[1])
177 ######################################################################
179 if __name__ == "__main__":
182 grid_factory = GridFactory()
184 start_time = time.perf_counter()
185 samples = grid_factory.generate_samples(10000)
186 end_time = time.perf_counter()
187 print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
189 scene, questions = grid_factory.generate_scene_and_questions()
190 grid_factory.print_scene(scene)
193 ######################################################################