# We discard the trivial ones, according to a criterion
# specific to the world quizzes (e.g. B=f(B))
- c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False]
+ rejected = []
+
+ to_keep == quiz_machine.problem.trivial(c_quizzes) == False
+
+ if not to_keep.all():
+ rejected.append(c_quizzes[to_keep == False])
+
+ c_quizzes = c_quizzes[to_keep]
# We go through nb_rounds rounds and keep only quizzes on
# which
number_correct_responses = 0
nb_remaining = [c_quizzes.size(0)]
- rejected = []
for r in range(args.nb_rounds):
if c_quizzes.size(0) == 0:
######################################################################
def autoregression(
+ self,
model,
input,
ar_mask,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
logit_transformer=logit_transformer,
- deterministic_synthesis=deterministic_synthesis,
+ deterministic_synthesis=False,
)
model.train(t)
ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
+ seq_logproba = torch.empty(quizzes.size(0), device=self.device)
+
self.autoregression(
model=model,
input=result,