self.check_structure(quizzes, struct)
return struct
+ # What a mess
def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
if torch.is_tensor(quizzes):
return self.reconfigure([quizzes], struct=struct)[0]
return result
- def non_trivial(self, quizzes):
+ def trivial(self, quizzes):
S = self.height * self.width
assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
- return (a[:, 0] == a[:, 1]).min(dim=1).values & (a[:, 2] == a[:, 3]).min(
+ return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
dim=1
).values
# We discard the trivial ones, according to a criterion
# specific to the world quizzes (e.g. B=f(B))
- c_quizzes = c_quizzes[quiz_machine.problem.non_trivial(c_quizzes)]
+ c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False]
# We go through nb_rounds rounds and keep only quizzes on
# which
######################################################################
- def randomize_configuations_inplace(self, quizzes, configurations):
- r = torch.randint(
- len(configurations), (quizzes.size(0),), device=quizzes.device
- )
+ def randomize_configuations_inplace(self, quizzes, structs):
+ r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
- for c in range(len(configurations)):
+ for c in range(len(structs)):
quizzes[r == c] = self.problem.reconfigure(
- quizzes[r == c], struct=configurations[c]
+ quizzes[r == c], struct=structs[c]
)
def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
self.randomize_configuations_inplace(
- model.train_w_quizzes, configurations=self.train_struct
+ model.train_w_quizzes, structs=self.train_struct
)
self.randomize_configuations_inplace(
- model.test_w_quizzes, configurations=self.train_struct
+ model.test_w_quizzes, structs=self.train_struct
)
######################################################################
)
self.randomize_configuations_inplace(
- model.train_w_quizzes, configurations=self.train_struct
+ model.train_w_quizzes, structs=self.train_struct
)
######################################################################