max_nb_transformations=3,
nb_questions=4,
):
+ assert size % 2 == 0
self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
properties += [f"a {n1} is right of a {n2}"]
if j1 < j2:
properties += [f"a {n1} is left of a {n2}"]
+ if abs(i1 - i2) + abs(j1 - j2) == 1:
+ properties += [f"a {n1} is next to a {n2}"]
return properties
while True:
while True:
start_scene = self.generate_scene()
- true = self.all_properties(start_scene)
+ scene, transformations = self.random_transformations(start_scene)
+ true = self.all_properties(scene)
if len(true) >= self.nb_questions:
break
- start = self.grid_positions(start_scene)
-
- scene, transformations = self.random_transformations(start_scene)
-
- # transformations=[]
-
for a in range(10):
col, shp = scene
col, shp = col.view(-1), shp.view(-1)
col.view(self.size, self.size),
shp.view(self.size, self.size),
)
- # other_scene = self.generate_scene()
- false = list(set(self.all_properties(other_scene)) - set(true))
+
+ false = self.all_properties(other_scene)
+
+ # We sometime add properties from a totally different
+ # scene to have negative "there is a xxx xxx"
+ # properties
+ if torch.rand(1).item() < 0.2:
+ other_scene = self.generate_scene()
+ false += self.all_properties(other_scene)
+
+ false = list(set(false) - set(true))
if len(false) >= self.nb_questions:
break
true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
- true = ["<prop> " + q + " <true>" for q in true]
- false = ["<prop> " + q + " <false>" for q in false]
+ true = ["<prop> " + q + " <ans> true" for q in true]
+ false = ["<prop> " + q + " <ans> false" for q in false]
union = true + false
questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
result = " ".join(
- ["<obj> " + x for x in self.grid_positions(scene)]
+ ["<obj> " + x for x in self.grid_positions(start_scene)]
+ transformations
+ questions
)
self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
self.t_nul = self.token2id["#"]
- self.t_true = self.token2id["<true>"]
- self.t_false = self.token2id["<false>"]
+ self.t_true = self.token2id["true"]
+ self.t_false = self.token2id["false"]
# Tokenize the train and test sets
self.train_input = self.tensorize(self.train_descr)