X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grid.py;h=1287ad52db6888610f2eb57589cfab977bcf3017;hb=HEAD;hp=60baedf727d891c745ac3a7b731424c08a89480b;hpb=3e3b9ead54130e5e3b2ce690943af9cb4c894e65;p=picoclvr.git diff --git a/grid.py b/grid.py index 60baedf..1287ad5 100755 --- a/grid.py +++ b/grid.py @@ -9,10 +9,6 @@ import math import torch, torchvision import torch.nn.functional as F -name_shapes = ["A", "B", "C", "D", "E", "F"] - -name_colors = ["red", "yellow", "blue", "green", "white", "purple"] - ###################################################################### @@ -23,52 +19,104 @@ class GridFactory: max_nb_items=4, max_nb_transformations=3, nb_questions=4, + nb_shapes=6, + nb_colors=6, + nb_play_steps=3, ): assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.nb_play_steps = nb_play_steps + self.name_shapes = ["A", "B", "C", "D", "E", "F"] + self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"] + self.vname_shapes = ["vA", "vB", "vC", "vD", "vE", "vF"] + self.vname_colors = ["vred", "vyellow", "vblue", "vgreen", "vwhite", "vpurple"] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 col = torch.full((self.size * self.size,), -1) shp = torch.full((self.size * self.size,), -1) - a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] - col[:nb_items] = a % len(name_colors) - shp[:nb_items] = a // len(name_colors) + a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] + col[:nb_items] = a % len(self.name_colors) + shp[:nb_items] = a // len(self.name_colors) i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) - def random_transformations(self, scene): + def random_object_move(self, scene): + col, shp = scene + while True: + a = (col.flatten() >= 0).nonzero() + a = a[torch.randint(a.size(0), (1,)).item()] + i, j = a // self.size, a % self.size + assert col[i, j] >= 0 + dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)] + dst = list( + filter( + lambda x: x[0] >= 0 + and x[1] >= 0 + and x[0] < self.size + and x[1] < self.size + and col[x[0], x[1]] < 0, + dst, + ) + ) + if len(dst) > 0: + ni, nj = dst[torch.randint(len(dst), (1,)).item()] + col[ni, nj] = col[i, j] + shp[ni, nj] = shp[i, j] + col[i, j] = -1 + shp[i, j] = -1 + break + + return col, shp + + def transformation(self, t, scene): col, shp = scene + if t == 0: + col, shp = col.flip(0), shp.flip(0) + description = " vertical flip" + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + description = " horizontal flip" + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + description = " rotate 90 degrees" + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + description = " rotate 180 degrees" + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + description = " rotate 270 degrees" + + return (col.contiguous(), shp.contiguous()), description + def random_transformations(self, scene): descriptions = [] nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() transformations = torch.randint(5, (nb_transformations,)) for t in transformations: - if t == 0: - col, shp = col.flip(0), shp.flip(0) - descriptions += [" vertical flip"] - elif t == 1: - col, shp = col.flip(1), shp.flip(1) - descriptions += [" horizontal flip"] - elif t == 2: - col, shp = col.flip(0).t(), shp.flip(0).t() - descriptions += [" rotate 90 degrees"] - elif t == 3: - col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) - descriptions += [" rotate 180 degrees"] - elif t == 4: - col, shp = col.flip(1).t(), shp.flip(1).t() - descriptions += [" rotate 270 degrees"] - - col, shp = col.contiguous(), shp.contiguous() - - return (col, shp), descriptions + scene, description = self.transformation(t, scene) + descriptions += [description] + + return scene, descriptions + + def visual_scene2str(self, scene): + col, shp = scene + r = [] + for i in range(self.size): + s = [] + for j in range(self.size): + if col[i, j] >= 0: + s += [self.vname_colors[col[i, j]], self.vname_shapes[shp[i, j]]] + else: + s += ["v_", "v+"] + r += s # .append(" ".join(s)) + return " ".join(r) def print_scene(self, scene): col, shp = scene @@ -76,12 +124,15 @@ class GridFactory: # for i in range(self.size): # for j in range(self.size): # if col[i,j] >= 0: - # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}") for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + print( + f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}", + end="", + ) elif j == 0: print(" +", end="") else: @@ -103,7 +154,7 @@ class GridFactory: for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] return properties @@ -116,7 +167,9 @@ class GridFactory: for i1 in range(self.size): for j1 in range(self.size): if col[i1, j1] >= 0: - n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + n1 = ( + f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}" + ) properties += [f"there is a {n1}"] if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] @@ -129,7 +182,7 @@ class GridFactory: for i2 in range(self.size): for j2 in range(self.size): if col[i2, j2] >= 0: - n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}" if i1 > i2: properties += [f"a {n1} is below a {n2}"] if i1 < i2: @@ -143,8 +196,22 @@ class GridFactory: return properties + def generate_scene_and_play(self): + scene = self.generate_scene() + steps = [self.visual_scene2str(scene)] + for t in range(self.nb_play_steps - 1): + if torch.randint(4, (1,)).item() == 0: + scene, _ = self.transformation(torch.randint(5, (1,)), scene) + else: + scene = self.random_object_move(scene) + steps.append(self.visual_scene2str(scene)) + return " | ".join(steps) + def generate_scene_and_questions(self): while True: + # We generate scenes until we get one with enough + # properties + while True: start_scene = self.generate_scene() scene, transformations = self.random_transformations(start_scene) @@ -152,6 +219,12 @@ class GridFactory: if len(true) >= self.nb_questions: break + # We generate a bunch of false properties by shuffling the + # scene and sometimes adding properties from totally + # different scenes. We try ten times to get enough false + # properties and go back to generating the scene if we do + # not succeed + for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -167,6 +240,7 @@ class GridFactory: # We sometime add properties from a totally different # scene to have negative "there is a xxx xxx" # properties + if torch.rand(1).item() < 0.2: other_scene = self.generate_scene() false += self.all_properties(other_scene) @@ -194,15 +268,18 @@ class GridFactory: return start_scene, scene, result - def generate_samples(self, nb, progress_bar=None): + def generate_samples(self, nb, fraction_play=0.0, progress_bar=None): result = [] - r = range(nb) + play = torch.rand(nb) < fraction_play if progress_bar is not None: - r = progress_bar(r) + play = progress_bar(play) - for _ in r: - result.append(self.generate_scene_and_questions()[2]) + for p in play: + if p: + result.append(self.generate_scene_and_play()) + else: + result.append(self.generate_scene_and_questions()[2]) return result @@ -220,11 +297,27 @@ if __name__ == "__main__": # print(f"{len(samples) / (end_time - start_time):.02f} samples per second") start_scene, scene, questions = grid_factory.generate_scene_and_questions() + print() print("-- Original scene -----------------------------") + print() grid_factory.print_scene(start_scene) + print() print("-- Transformed scene --------------------------") + print() grid_factory.print_scene(scene) + print() print("-- Sequence -----------------------------------") + print() print(questions) + # print(grid_factory.visual_scene2str(scene)) + + # grid_factory.print_scene(scene) + # for t in range(5): + # scene = grid_factory.random_object_move(scene) + # print() + # grid_factory.print_scene(scene) + + print(grid_factory.generate_scene_and_play()) + ######################################################################