        self.height = height
        self.width = width
        self.max_nb_items = max_nb_items
+       self.max_nb_transformations = max_nb_transformations
        self.nb_questions = nb_questions
    def generate_scene(self):
            self.height, self.width
        )
-   def random_transformations(self):
+   def random_transformations(self, scene):
+       col, shp = scene
+       descriptions = []
+       nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
+       transformations = torch.randint(5, (nb_transformations,))
+
+       for t in transformations:
+           if t == 0:
+               col, shp = col.flip(0), shp.flip(0)
+               descriptions += ["<chg> vertical flip"]
+           elif t == 1:
+               col, shp = col.flip(1), shp.flip(1)
+               descriptions += ["<chg> horizontal flip"]
+           elif t == 2:
+               col, shp = col.flip(0).t(), shp.flip(0).t()
+               descriptions += ["<chg> rotate 90 degrees"]
+           elif t == 3:
+               col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+               descriptions += ["<chg> rotate 180 degrees"]
+           elif t == 4:
+               col, shp = col.flip(1).t(), shp.flip(1).t()
+               descriptions += ["<chg> rotate 270 degrees"]
+
+       return (col.contiguous(), shp.contiguous()), descriptions
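
# Not part of the diff: a quick sanity check that the flip()/t()
# combinations above implement the clockwise rotations their "<chg>"
# labels claim, by comparing against torch.rot90 on a toy 2x3 grid.
import torch

g = torch.arange(6).reshape(2, 3)
assert torch.equal(g.flip(0).t(), torch.rot90(g, -1))     # rotate 90 degrees
assert torch.equal(g.flip(0).flip(1), torch.rot90(g, 2))  # rotate 180 degrees
assert torch.equal(g.flip(1).t(), torch.rot90(g, 1))      # rotate 270 degrees
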
    def print_scene(self, scene):
        col, shp = scene
        start = self.grid_positions(scene)
+       scene, transformations = self.random_transformations(scene)
+
        for a in range(10):
            col, shp = scene
            col, shp = col.view(-1), shp.view(-1)
        questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
        result = " ".join(
-           ["<obj> " + x for x in self.grid_positions(scene)] + questions
+           ["<obj> " + x for x in self.grid_positions(scene)]
+           + transformations
+           + questions
        )
        return scene, result
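
# Illustration only, with made-up object and question strings (the real
# ones come from grid_positions and the question generator): the new
# join order places the "<chg>" tags between the objects and the questions.
objects = ["red square top left", "blue circle bottom right"]  # hypothetical
transformations = ["<chg> vertical flip"]
questions = ["is there a red square <true>"]                   # hypothetical
result = " ".join(["<obj> " + x for x in objects] + transformations + questions)
# -> "<obj> red square top left <obj> blue circle bottom right
#     <chg> vertical flip is there a red square <true>"
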
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
-       token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+       token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
        return torch.tensor(id_descr, device=self.device)
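
# Minimal sketch of what tensorize now produces, with an assumed toy
# vocabulary: descriptions are right-padded with "#" to the longest length.
descr = ["<obj> a b", "<obj> a"]
token2id = {"#": 0, "<obj>": 1, "a": 2, "b": 3}  # hypothetical token2id
token_descr = [s.strip().split(" ") for s in descr]
l = max([len(s) for s in token_descr])
token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
id_descr = [[token2id[u] for u in s] for s in token_descr]
# id_descr == [[1, 2, 3], [1, 2, 0]]
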
    # trim all the tensors in the tuple z to remove as many tokens as
    # possible from the left and right of the first tensor. If z is a
    # tuple, all its elements are trimmed according to the trimming of
    # the first one
-   def trim(self, z, token="<nul>"):
+   def trim(self, z, token="#"):
        n = self.token2id[token]
        if type(z) == tuple:
            x = z[0]
        )
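
# The body of trim is elided above; this is one way to implement what
# the comment describes for a single (batch, length) tensor: a sketch
# under that assumption, not the diff's actual code.
import torch

def trim_sketch(x, pad_id):
    keep = (x != pad_id).any(0)  # columns holding at least one real token
    idx = keep.nonzero()
    if idx.numel() == 0:
        return x[:, :0]          # everything is padding
    a, b = idx.min().item(), idx.max().item() + 1
    return x[:, a:b]             # drop all-padding columns on both sides
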
        # Build the tokenizer
-       tokens = {}
+       tokens = set()
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        # make this set a sorted list to get the same tensors given
        # the same descr
        tokens = list(tokens)
        tokens.sort()
-       tokens = ["<nul>"] + tokens
+       tokens = ["#"] + tokens
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-       self.t_nul = self.token2id["<nul>"]
+       self.t_nul = self.token2id["#"]
        self.t_true = self.token2id["<true>"]
        self.t_false = self.token2id["<false>"]
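
# Illustrative check (not in the diff): because "#" is prepended before
# the dicts are built, the padding token always gets id 0 and the two
# dicts invert each other.
tokens = ["#"] + sorted({"<true>", "<false>", "<obj>", "<chg>"})
token2id = dict([(t, n) for n, t in enumerate(tokens)])
id2token = dict([(n, t) for n, t in enumerate(tokens)])
assert token2id["#"] == 0
assert all(token2id[id2token[n]] == n for n in id2token)
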