From 128d372813e99d8474bb6e967d5c7e7f085c819d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 6 Feb 2024 15:15:21 +0100 Subject: [PATCH] Update. --- grid.py | 132 ++++++++++++++++++++++++++++++++++++++++++++----------- main.py | 3 ++ tasks.py | 7 ++- 3 files changed, 115 insertions(+), 27 deletions(-) diff --git a/grid.py b/grid.py index 2135710..1287ad5 100755 --- a/grid.py +++ b/grid.py @@ -21,14 +21,18 @@ class GridFactory: nb_questions=4, nb_shapes=6, nb_colors=6, + nb_play_steps=3, ): assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.nb_play_steps = nb_play_steps self.name_shapes = ["A", "B", "C", "D", "E", "F"] self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"] + self.vname_shapes = ["vA", "vB", "vC", "vD", "vE", "vF"] + self.vname_colors = ["vred", "vyellow", "vblue", "vgreen", "vwhite", "vpurple"] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 @@ -42,33 +46,77 @@ class GridFactory: shp = shp[i] return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) - def random_transformations(self, scene): + def random_object_move(self, scene): col, shp = scene + while True: + a = (col.flatten() >= 0).nonzero() + a = a[torch.randint(a.size(0), (1,)).item()] + i, j = a // self.size, a % self.size + assert col[i, j] >= 0 + dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)] + dst = list( + filter( + lambda x: x[0] >= 0 + and x[1] >= 0 + and x[0] < self.size + and x[1] < self.size + and col[x[0], x[1]] < 0, + dst, + ) + ) + if len(dst) > 0: + ni, nj = dst[torch.randint(len(dst), (1,)).item()] + col[ni, nj] = col[i, j] + shp[ni, nj] = shp[i, j] + col[i, j] = -1 + shp[i, j] = -1 + break + + return col, shp + def transformation(self, t, scene): + col, shp = scene + if t == 0: + col, shp = col.flip(0), shp.flip(0) + description = " vertical flip" + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + description = " horizontal flip" + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + description = " rotate 90 degrees" + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + description = " rotate 180 degrees" + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + description = " rotate 270 degrees" + + return (col.contiguous(), shp.contiguous()), description + + def random_transformations(self, scene): descriptions = [] nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() transformations = torch.randint(5, (nb_transformations,)) for t in transformations: - if t == 0: - col, shp = col.flip(0), shp.flip(0) - descriptions += [" vertical flip"] - elif t == 1: - col, shp = col.flip(1), shp.flip(1) - descriptions += [" horizontal flip"] - elif t == 2: - col, shp = col.flip(0).t(), shp.flip(0).t() - descriptions += [" rotate 90 degrees"] - elif t == 3: - col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) - descriptions += [" rotate 180 degrees"] - elif t == 4: - col, shp = col.flip(1).t(), shp.flip(1).t() - descriptions += [" rotate 270 degrees"] - - col, shp = col.contiguous(), shp.contiguous() - - return (col, shp), descriptions + scene, description = self.transformation(t, scene) + descriptions += [description] + + return scene, descriptions + + def visual_scene2str(self, scene): + col, shp = scene + r = [] + for i in range(self.size): + s = [] + for j in range(self.size): + if col[i, j] >= 0: + s += [self.vname_colors[col[i, j]], self.vname_shapes[shp[i, j]]] + else: + s += ["v_", "v+"] + r += s # .append(" ".join(s)) + return " ".join(r) def print_scene(self, scene): col, shp = scene @@ -148,8 +196,22 @@ class GridFactory: return properties + def generate_scene_and_play(self): + scene = self.generate_scene() + steps = [self.visual_scene2str(scene)] + for t in range(self.nb_play_steps - 1): + if torch.randint(4, (1,)).item() == 0: + scene, _ = self.transformation(torch.randint(5, (1,)), scene) + else: + scene = self.random_object_move(scene) + steps.append(self.visual_scene2str(scene)) + return " | ".join(steps) + def generate_scene_and_questions(self): while True: + # We generate scenes until we get one with enough + # properties + while True: start_scene = self.generate_scene() scene, transformations = self.random_transformations(start_scene) @@ -157,6 +219,12 @@ class GridFactory: if len(true) >= self.nb_questions: break + # We generate a bunch of false properties by shuffling the + # scene and sometimes adding properties from totally + # different scenes. We try ten times to get enough false + # properties and go back to generating the scene if we do + # not succeed + for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -172,6 +240,7 @@ class GridFactory: # We sometime add properties from a totally different # scene to have negative "there is a xxx xxx" # properties + if torch.rand(1).item() < 0.2: other_scene = self.generate_scene() false += self.all_properties(other_scene) @@ -199,15 +268,18 @@ class GridFactory: return start_scene, scene, result - def generate_samples(self, nb, progress_bar=None): + def generate_samples(self, nb, fraction_play=0.0, progress_bar=None): result = [] - r = range(nb) + play = torch.rand(nb) < fraction_play if progress_bar is not None: - r = progress_bar(r) + play = progress_bar(play) - for _ in r: - result.append(self.generate_scene_and_questions()[2]) + for p in play: + if p: + result.append(self.generate_scene_and_play()) + else: + result.append(self.generate_scene_and_questions()[2]) return result @@ -238,4 +310,14 @@ if __name__ == "__main__": print() print(questions) + # print(grid_factory.visual_scene2str(scene)) + + # grid_factory.print_scene(scene) + # for t in range(5): + # scene = grid_factory.random_object_move(scene) + # print() + # grid_factory.print_scene(scene) + + print(grid_factory.generate_scene_and_play()) + ###################################################################### diff --git a/main.py b/main.py index 69731ff..9f82594 100755 --- a/main.py +++ b/main.py @@ -104,6 +104,8 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_fraction_play", type=float, default=0) + ############################## # picoclvr options @@ -554,6 +556,7 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + fraction_play=args.grid_fraction_play, logger=log_string, device=device, ) diff --git a/tasks.py b/tasks.py index a53d213..08aa8ca 100755 --- a/tasks.py +++ b/tasks.py @@ -1475,6 +1475,7 @@ class Grid(Task): nb_test_samples, batch_size, size, + fraction_play=0.0, logger=None, device=torch.device("cpu"), ): @@ -1490,10 +1491,12 @@ class Grid(Task): ) self.train_descr = self.grid_factory.generate_samples( - nb_train_samples, lambda r: tqdm.tqdm(r) + nb=nb_train_samples, + fraction_play=fraction_play, + progress_bar=lambda r: tqdm.tqdm(r), ) self.test_descr = self.grid_factory.generate_samples( - nb_test_samples, lambda r: tqdm.tqdm(r) + nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r) ) # Build the tokenizer -- 2.39.5