X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grid.py;h=433cfd5a5babea1c7605e5dfe63ca66245f3120a;hb=ce969e8372fb161d86be29042a20b044ee6efe2a;hp=08ddc232f7fa7a067290482c7ec79ef059702da6;hpb=b87078aec53ead1e0a3ca44d4ac46c319bbcd63e;p=picoclvr.git diff --git a/grid.py b/grid.py index 08ddc23..433cfd5 100755 --- a/grid.py +++ b/grid.py @@ -28,6 +28,7 @@ class GridFactory: self.height = height self.width = width self.max_nb_items = max_nb_items + self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions def generate_scene(self): @@ -44,8 +45,30 @@ class GridFactory: self.height, self.width ) - def random_transformations(self): + def random_transformations(self, scene): + col, shp = scene + descriptions = [] nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() + transformations = torch.randint(5, (nb_transformations,)) + + for t in transformations: + if t == 0: + col, shp = col.flip(0), shp.flip(0) + descriptions += [" vertical flip"] + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + descriptions += [" horizontal flip"] + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + descriptions += [" rotate 90 degrees"] + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + descriptions += [" rotate 180 degrees"] + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + descriptions += [" rotate 270 degrees"] + + return (col.contiguous(), shp.contiguous()), descriptions def print_scene(self, scene): col, shp = scene @@ -118,7 +141,7 @@ class GridFactory: return properties - def generate_example(self): + def generate_scene_and_questions(self): while True: while True: scene = self.generate_scene() @@ -128,6 +151,8 @@ class GridFactory: start = self.grid_positions(scene) + scene, transformations = self.random_transformations(scene) + for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -142,25 +167,53 @@ class GridFactory: if len(false) >= self.nb_questions: break + # print(f"{a=}") + if a < 10: break true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] - true = [(q, "yes") for q in true] - false = [(q, "no") for q in false] + true = [" " + q + " " for q in true] + false = [" " + q + " " for q in false] union = true + false questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] - return scene, questions + result = " ".join( + [" " + x for x in self.grid_positions(scene)] + + transformations + + questions + ) + + return scene, result + + def generate_samples(self, nb, progress_bar=None): + result = [] + + r = range(nb) + if progress_bar is not None: + r = progress_bar(r) + + for _ in r: + result.append(self.generate_scene_and_questions()[1]) + + return result ###################################################################### if __name__ == "__main__": + import time + grid_factory = GridFactory() - scene, questions = grid_factory.generate_example() + + start_time = time.perf_counter() + samples = grid_factory.generate_samples(10000) + end_time = time.perf_counter() + print(f"{len(samples) / (end_time - start_time):.02f} samples per second") + + scene, questions = grid_factory.generate_scene_and_questions() grid_factory.print_scene(scene) print(questions)