###############################################################
- def optimize_quizzes(self, quiz, nb_variants, nb_iterations, struct, mask):
+ def optimize_quizzes(
+ self, models, quiz, nb_variants, nb_iterations, struct, mask, proba_understands
+ ):
for _ in range(nb_iterations):
- candidates = quizzes[None].expand(nb_variants, -1)
+ candidates = quiz[None, :].expand(nb_variants, -1).clone()
r = torch.rand(candidates.size(), device=candidates.device)
u = r.reshape(r.size(0), 4, candidates.size(1) // 4)
# Only change the part indicated by the mask and do not
# touch the special tokens
u[:, :, 0] = 0
- u = u * torch.tensor(mask, device=u.device)[None, :, None]
- random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
+ u = u * (1 - torch.tensor(mask, device=u.device)[None, :, None])
+ random_mask = F.one_hot(r.argmax(dim=1), num_classes=r.size(1))
# Keep the first unchanged
- random_mask[:, 0, :] = 0
+ random_mask[0, :] = 0
# Reshape without the 4 parts
candidates.reshape(-1, candidates.size(-1))
random_mask.reshape(candidates.size())
random_tokens = torch.randint(
- self.problem.nb_token_values - 4, random_mask.size()
+ self.problem.nb_token_values - 4,
+ random_mask.size(),
+ device=candidates.device,
)
# Apply the noise
candidates = (1 - random_mask) * candidates + random_mask * random_tokens
- seq_logproba = quiz_machine.models_logprobas(
+ seq_logproba = self.models_logprobas(
models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
- ) + quiz_machine.models_logprobas(
+ ) + self.models_logprobas(
models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
)
sorted_logprobas = seq_logproba.sort(dim=1).values.exp()
lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1]
score = second_lowest - lowest
- score = score * (second_lowest > args.proba_understands)
+ # score = score * (second_lowest > proba_understands)
quiz = candidates[score.argmax()]
+ print(score.max())
- return quiz
+ return quiz.to("cpu")
def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
seq_logproba = torch.zeros(nb, device=self.device)