From: François Fleuret Date: Fri, 25 Aug 2023 16:58:43 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ce969e8372fb161d86be29042a20b044ee6efe2a;p=picoclvr.git Update. --- diff --git a/grid.py b/grid.py index 70f7739..433cfd5 100755 --- a/grid.py +++ b/grid.py @@ -28,6 +28,7 @@ class GridFactory: self.height = height self.width = width self.max_nb_items = max_nb_items + self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions def generate_scene(self): @@ -44,8 +45,30 @@ class GridFactory: self.height, self.width ) - def random_transformations(self): + def random_transformations(self, scene): + col, shp = scene + descriptions = [] nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() + transformations = torch.randint(5, (nb_transformations,)) + + for t in transformations: + if t == 0: + col, shp = col.flip(0), shp.flip(0) + descriptions += [" vertical flip"] + elif t == 1: + col, shp = col.flip(1), shp.flip(1) + descriptions += [" horizontal flip"] + elif t == 2: + col, shp = col.flip(0).t(), shp.flip(0).t() + descriptions += [" rotate 90 degrees"] + elif t == 3: + col, shp = col.flip(0).flip(1), shp.flip(0).flip(1) + descriptions += [" rotate 180 degrees"] + elif t == 4: + col, shp = col.flip(1).t(), shp.flip(1).t() + descriptions += [" rotate 270 degrees"] + + return (col.contiguous(), shp.contiguous()), descriptions def print_scene(self, scene): col, shp = scene @@ -128,6 +151,8 @@ class GridFactory: start = self.grid_positions(scene) + scene, transformations = self.random_transformations(scene) + for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) @@ -156,7 +181,9 @@ class GridFactory: questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] result = " ".join( - [" " + x for x in self.grid_positions(scene)] + questions + [" " + x for x in self.grid_positions(scene)] + + transformations + + questions ) return scene, result diff --git a/tasks.py b/tasks.py index c7348d5..0ab1823 100755 --- a/tasks.py +++ b/tasks.py @@ -1429,7 +1429,7 @@ class Grid(Task): def tensorize(self, descr): token_descr = [s.strip().split(" ") for s in descr] l = max([len(s) for s in token_descr]) - token_descr = [s + [""] * (l - len(s)) for s in token_descr] + token_descr = [s + ["#"] * (l - len(s)) for s in token_descr] id_descr = [[self.token2id[u] for u in s] for s in token_descr] return torch.tensor(id_descr, device=self.device) @@ -1440,7 +1440,7 @@ class Grid(Task): # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its # elements are trimed according to the triming for the first - def trim(self, z, token=""): + def trim(self, z, token="#"): n = self.token2id[token] if type(z) == tuple: x = z[0] @@ -1483,7 +1483,7 @@ class Grid(Task): ) # Build the tokenizer - tokens = {} + tokens = set() for d in [self.train_descr, self.test_descr]: for s in d: for t in s.strip().split(" "): @@ -1492,10 +1492,10 @@ class Grid(Task): # the same descr tokens = list(tokens) tokens.sort() - tokens = [""] + tokens + tokens = ["#"] + tokens self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) - self.t_nul = self.token2id[""] + self.t_nul = self.token2id["#"] self.t_true = self.token2id[""] self.t_false = self.token2id[""]