X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grid.py;h=2135710e68faa57af119b1e9193c32223a1fc2f5;hb=ac3d9ba45d72a7f3e399de4e3614698ac5e0ce39;hp=f72c8e34ce417e50daf6135d4dc73d6af44fca04;hpb=0e1e208852b83f6a3d59e5caabd2f0f1f4bde94e;p=picoclvr.git diff --git a/grid.py b/grid.py index f72c8e3..2135710 100755 --- a/grid.py +++ b/grid.py @@ -9,33 +9,34 @@ import math import torch, torchvision import torch.nn.functional as F -name_shapes = ["A", "B", "C", "D", "E", "F"] - -name_colors = ["red", "yellow", "blue", "green", "white", "purple"] - ###################################################################### class GridFactory: def __init__( self, - size=4, + size=6, max_nb_items=4, max_nb_transformations=3, nb_questions=4, + nb_shapes=6, + nb_colors=6, ): + assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.name_shapes = ["A", "B", "C", "D", "E", "F"] + self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 col = torch.full((self.size * self.size,), -1) shp = torch.full((self.size * self.size,), -1) - a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] - col[:nb_items] = a % len(name_colors) - shp[:nb_items] = a // len(name_colors) + a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] + col[:nb_items] = a % len(self.name_colors) + shp[:nb_items] = a // len(self.name_colors) i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] @@ -75,12 +76,15 @@ class GridFactory: # for i in range(self.size): # for j in range(self.size): # if col[i,j] >= 0: - # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}") for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + print( + f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}", + end="", + ) elif j == 0: print(" +", end="") else: @@ -102,7 +106,7 @@ class GridFactory: for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] return properties @@ -115,7 +119,9 @@ class GridFactory: for i1 in range(self.size): for j1 in range(self.size): if col[i1, j1] >= 0: - n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + n1 = ( + f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}" + ) properties += [f"there is a {n1}"] if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] @@ -128,7 +134,7 @@ class GridFactory: for i2 in range(self.size): for j2 in range(self.size): if col[i2, j2] >= 0: - n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}" if i1 > i2: properties += [f"a {n1} is below a {n2}"] if i1 < i2: @@ -137,23 +143,20 @@ class GridFactory: properties += [f"a {n1} is right of a {n2}"] if j1 < j2: properties += [f"a {n1} is left of a {n2}"] + if abs(i1 - i2) + abs(j1 - j2) == 1: + properties += [f"a {n1} is next to a {n2}"] return properties def generate_scene_and_questions(self): while True: while True: - scene = self.generate_scene() + start_scene = self.generate_scene() + scene, transformations = self.random_transformations(start_scene) true = self.all_properties(scene) if len(true) >= self.nb_questions: break - start = self.grid_positions(scene) - - scene, transformations = self.random_transformations(scene) - - # transformations=[] - for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -163,8 +166,17 @@ class GridFactory: col.view(self.size, self.size), shp.view(self.size, self.size), ) - # other_scene = self.generate_scene() - false = list(set(self.all_properties(other_scene)) - set(true)) + + false = self.all_properties(other_scene) + + # We sometime add properties from a totally different + # scene to have negative "there is a xxx xxx" + # properties + if torch.rand(1).item() < 0.2: + other_scene = self.generate_scene() + false += self.all_properties(other_scene) + + false = list(set(false) - set(true)) if len(false) >= self.nb_questions: break @@ -173,19 +185,19 @@ class GridFactory: true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] - true = [" " + q + " " for q in true] - false = [" " + q + " " for q in false] + true = [" " + q + " true" for q in true] + false = [" " + q + " false" for q in false] union = true + false questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] result = " ".join( - [" " + x for x in self.grid_positions(scene)] + [" " + x for x in self.grid_positions(start_scene)] + transformations + questions ) - return scene, result + return start_scene, scene, result def generate_samples(self, nb, progress_bar=None): result = [] @@ -195,7 +207,7 @@ class GridFactory: r = progress_bar(r) for _ in r: - result.append(self.generate_scene_and_questions()[1]) + result.append(self.generate_scene_and_questions()[2]) return result @@ -207,13 +219,23 @@ if __name__ == "__main__": grid_factory = GridFactory() - start_time = time.perf_counter() - samples = grid_factory.generate_samples(10000) - end_time = time.perf_counter() - print(f"{len(samples) / (end_time - start_time):.02f} samples per second") - - scene, questions = grid_factory.generate_scene_and_questions() + # start_time = time.perf_counter() + # samples = grid_factory.generate_samples(10000) + # end_time = time.perf_counter() + # print(f"{len(samples) / (end_time - start_time):.02f} samples per second") + + start_scene, scene, questions = grid_factory.generate_scene_and_questions() + print() + print("-- Original scene -----------------------------") + print() + grid_factory.print_scene(start_scene) + print() + print("-- Transformed scene --------------------------") + print() grid_factory.print_scene(scene) + print() + print("-- Sequence -----------------------------------") + print() print(questions) ######################################################################