From 3e3b9ead54130e5e3b2ce690943af9cb4c894e65 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 26 Aug 2023 00:47:09 +0200 Subject: [PATCH] Update. --- grid.py | 31 +++++++++++++++++++------------ tasks.py | 4 ++-- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/grid.py b/grid.py index 5b28914..60baedf 100755 --- a/grid.py +++ b/grid.py @@ -24,6 +24,7 @@ class GridFactory: max_nb_transformations=3, nb_questions=4, ): + assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations @@ -137,6 +138,8 @@ class GridFactory: properties += [f"a {n1} is right of a {n2}"] if j1 < j2: properties += [f"a {n1} is left of a {n2}"] + if abs(i1 - i2) + abs(j1 - j2) == 1: + properties += [f"a {n1} is next to a {n2}"] return properties @@ -144,16 +147,11 @@ class GridFactory: while True: while True: start_scene = self.generate_scene() - true = self.all_properties(start_scene) + scene, transformations = self.random_transformations(start_scene) + true = self.all_properties(scene) if len(true) >= self.nb_questions: break - start = self.grid_positions(start_scene) - - scene, transformations = self.random_transformations(start_scene) - - # transformations=[] - for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -163,8 +161,17 @@ class GridFactory: col.view(self.size, self.size), shp.view(self.size, self.size), ) - # other_scene = self.generate_scene() - false = list(set(self.all_properties(other_scene)) - set(true)) + + false = self.all_properties(other_scene) + + # We sometime add properties from a totally different + # scene to have negative "there is a xxx xxx" + # properties + if torch.rand(1).item() < 0.2: + other_scene = self.generate_scene() + false += self.all_properties(other_scene) + + false = list(set(false) - set(true)) if len(false) >= self.nb_questions: break @@ -173,14 +180,14 @@ class GridFactory: true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] - true = [" " + q + " " for q in true] - false = [" " + q + " " for q in false] + true = [" " + q + " true" for q in true] + false = [" " + q + " false" for q in false] union = true + false questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] result = " ".join( - [" " + x for x in self.grid_positions(scene)] + [" " + x for x in self.grid_positions(start_scene)] + transformations + questions ) diff --git a/tasks.py b/tasks.py index 24c13fe..cbc8e6b 100755 --- a/tasks.py +++ b/tasks.py @@ -1495,8 +1495,8 @@ class Grid(Task): self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) self.t_nul = self.token2id["#"] - self.t_true = self.token2id[""] - self.t_false = self.token2id[""] + self.t_true = self.token2id["true"] + self.t_false = self.token2id["false"] # Tokenize the train and test sets self.train_input = self.tensorize(self.train_descr) -- 2.39.5