parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
-parser.add_argument("--proba_understands", type=float, default=0.9)
+parser.add_argument("--proba_understands", type=float, default=0.99)
parser.add_argument("--proba_not_understands", type=float, default=0.5)
# This is nb_quizzes x nb_models
- number_correct_responses = 0
- nb_remaining = [c_quizzes.size(0)]
-
- for r in range(args.nb_rounds):
- if c_quizzes.size(0) == 0:
- break
-
- number_correct_responses += quiz_machine.models_successes(models, c_quizzes)
-
- nb_sure_correct = (number_correct_responses == r + 1).long().sum(dim=1)
- nb_sure_fail = (number_correct_responses == 0).long().sum(dim=1)
-
- to_keep = (
- (nb_sure_correct + nb_sure_fail == number_correct_responses.size(1))
- & (nb_sure_fail >= 1)
- & (nb_sure_fail <= args.max_fail_to_validate)
- )
-
- if not to_keep.all():
- rejected.append(c_quizzes[to_keep == False])
+    # Per-model log-probability that each model reproduces the c_quizz, in
+    # both presentation orders; summed so a quiz must be handled both ways.
+    seq_logproba = quiz_machine.models_logprobas(
+        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+    ) + quiz_machine.models_logprobas(
+        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+    )
-    c_quizzes = c_quizzes[to_keep]
-    number_correct_responses = number_correct_responses[to_keep]
+    probas = seq_logproba.exp()
+    nb_sure_correct = (probas >= args.proba_understands).long().sum(dim=1)
+    # BUG FIX: a model is a "sure fail" only when its proba is at or below
+    # the *not_understands* threshold (default 0.5).  Comparing against
+    # proba_understands (0.99) made the two counts overlap and classified
+    # almost every model as failing, breaking the filter below.
+    nb_sure_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
-    nb_remaining.append(c_quizzes.size(0))
+    # Keep a quiz iff every model is decided (sure correct or sure fail)
+    # and between 1 and max_fail_to_validate models fail it.
+    to_keep = (
+        (nb_sure_correct + nb_sure_fail == probas.size(1))
+        & (nb_sure_fail >= 1)
+        & (nb_sure_fail <= args.max_fail_to_validate)
+    )
+    # Rejected quizzes are recycled as seeds for the next generation pass.
+    to_recycle = c_quizzes[to_keep == False] if not to_keep.all() else None
+    c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
else:
e = "???"
- v = " ".join([str(n) for n in nb_remaining])
-
log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h) filtering {v}"
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
)
validated_quizzes = torch.cat(recorded_validated, dim=0)
vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
if vq.size(0) > 0:
- vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
- number_correct_responses = 0
-
- for r in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"):
- number_correct_responses += quiz_machine.models_successes(models, vq)
-
- seq_logproba = quiz_machine.models_logprobas(models, vq)
+ seq_logproba = quiz_machine.models_logprobas(
+ models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+ ) + quiz_machine.models_logprobas(
+ models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+ )
comments = []
-    for l, r in zip(seq_logproba, number_correct_responses):
-        comments.append(
-            "nb_correct "
-            + " ".join([str(n.item()) for n in r])
-            + "\n"
-            + "proba "
-            + " ".join([str(x.item()) for x in l])
-        )
+    # One comment per quiz image: the per-model probabilities of solving it.
+    for l in seq_logproba:
+        # BUG FIX: a stray unary "+" before the string literal
+        # (`+"proba "`) raised TypeError: bad operand type for unary +: 'str'.
+        comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
filename = f"culture_c_quiz_{n_epoch:04d}.png"
quiz_machine.problem.save_quizzes_as_image(
######################################################################
- def models_logprobas(self, models_for_validation, c_quizzes, device=None):
+ def models_logprobas(
+ self, models_for_validation, c_quizzes, struct, mask, device=None
+ ):
if device is None:
device = self.device
- c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+ c_quizzes = self.problem.reconfigure(c_quizzes, struct)
seq_logproba = torch.zeros(
c_quizzes.size(0),
seq_logproba.split(self.batch_size),
):
input = input.to(device)
- ar_mask = self.make_ar_mask(input)
+ ar_mask = self.make_ar_mask(input, struct, mask)
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
output.transpose(1, 2), input, reduction="none"
)
* ar_mask
- ).sum()
+ ).sum(dim=1)
model.train(t)
###############################################################
+    def optimize_quizzes(self, quizzes, nb_variants, nb_iterations, struct, mask):
+        # NOTE(review): WIP — the loop never updates `quizzes` and the
+        # function returns nothing, so it looks truncated; confirm intent
+        # before relying on it.  Each iteration builds nb_variants copies of
+        # every quiz and resamples one random token per copy, restricted to
+        # the quarters selected by `mask` and never the first (structure
+        # marker) token of a quarter.
+        for _ in range(nb_iterations):
+            candidates = quizzes[:, None].expand(-1, nb_variants, -1)
+            r = torch.rand(candidates.size(), device=candidates.device)
+            # View r as (batch*variants, 4 quarters, quarter_len) so the
+            # leading token of each quarter can be protected from mutation.
+            u = r.reshape(
+                candidates.size(0) * candidates.size(1), 4, candidates.size(2) // 4
+            )
+            u[:, :, 0] = 0
+            # BUG FIX: the gating by `mask` was computed out-of-place and
+            # its result discarded; multiply in place so it writes through
+            # the view into r and actually restricts the mutated quarters.
+            u *= torch.tensor(mask, device=u.device)[None, :, None]
+            # NOTE(review): the sort runs over dim 0 (batch); presumably the
+            # intent is to pick the max-r position per candidate — confirm.
+            random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
+            random_mask[:, 0] = 0
+            # BUG FIX: both reshapes discarded their results (Tensor.reshape
+            # is never in-place), leaving these lines as no-ops; assign them
+            # so the flattened 2D forms are what the code below operates on.
+            candidates = candidates.reshape(-1, candidates.size(-1))
+            random_mask = random_mask.reshape(candidates.size())
+            random_tokens = torch.randint(
+                self.problem.nb_token_values - 4, random_mask.size()
+            )
+            candidates = (1 - random_mask) * candidates + random_mask * random_tokens
+            # BUG FIX: the original passed the undefined name `make_ar_mask`
+            # as the third argument (NameError at runtime) and the trailing
+            # comma wrapped the result in a 1-tuple; pass `mask` and keep the
+            # tensor, matching make_ar_mask(input, struct, mask) usage
+            # elsewhere in this class.
+            ar_mask = self.make_ar_mask(candidates, struct, mask)
+
def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
seq_logproba = torch.zeros(nb, device=self.device)