# Re-score a random sample of the validated c_quizzes and save them as an
# image whose captions carry per-model statistics.
#
# NOTE(review): this fragment was recovered from an unresolved diff with its
# indentation stripped; the statement nesting below is the most plausible
# reconstruction — confirm against the original file.

# Pick at most 128 validated quizzes, in random order.
vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]

if vq.size(0) > 0:
    # Put the quizzes in the canonical "forward" configuration before scoring.
    vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))

    # Accumulate, over 10 sampling rounds, how many times each model answers
    # each quiz correctly.
    number_correct_responses = 0
    for _ in tqdm.tqdm(range(10), dynamic_ncols=True, desc="re-scoring c_quizzes"):
        number_correct_responses += quiz_machine.models_successes(models, vq)

    # The log-probas are deterministic, so a single call suffices (no need to
    # recompute them inside the sampling loop above).
    seq_logproba = quiz_machine.models_logprobas(models, vq)

    # One two-line caption per quiz: the per-model correct-answer counts and
    # the per-model log-probas.
    comments = []
    for l, r in zip(seq_logproba, number_correct_responses):
        comments.append(
            "nb_correct "
            + " ".join([str(n.item()) for n in r])
            + "\n"
            + "proba "
            + " ".join([str(x.item()) for x in l])
        )

    filename = f"culture_c_quiz_{n_epoch:04d}.png"
    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir, filename, vq, comments=comments
    )

# NOTE(review): this assignment looks orphaned — it presumably belongs inside
# a loop over all the models (resetting their accuracy after new c_quizzes are
# recorded); confirm its placement against the original file.
model.main_test_accuracy = 0.0
##################################################
# Select, improve, and eval the worst model(s)

# Sort by ascending test accuracy so the worst performers come first.
ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
######################################################################

def models_logprobas(self, models_for_validation, c_quizzes, device=None):
    """Return, for every c_quizz and every model, the log-proba of the
    masked (autoregressive) tokens under that model.

    Args:
        models_for_validation: models, each with an integer ``.id`` attribute.
        c_quizzes: tensor of token sequences, one quiz per row.
        device: device to run the models on; defaults to ``self.device``.

    Returns:
        A float tensor of shape ``(nb_quizzes, max_model_id + 1)`` on the
        CPU; entry ``[q, m.id]`` is the summed token log-proba of quiz ``q``
        under model ``m``.
    """
    if device is None:
        device = self.device

    # Canonical "forward" configuration before scoring.
    c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))

    seq_logproba = torch.zeros(
        c_quizzes.size(0),
        max([m.id for m in models_for_validation]) + 1,
        device=device,
    )

    for model in models_for_validation:
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            # torch.split returns views, so writing into the per-batch slices
            # fills the full seq_logproba tensor in place.
            for batch, logp in zip(
                c_quizzes.split(self.batch_size),
                seq_logproba.split(self.batch_size),
            ):
                batch = batch.to(device)
                ar_mask = self.make_ar_mask(batch)
                output = model(mygpt.BracketedSequence(batch)).x
                # Per-token log-probas, masked to the autoregressive part,
                # summed per sequence. BUG FIX: a plain .sum() collapsed the
                # whole minibatch to one scalar that broadcast into every row
                # of logp[:, model.id]; .sum(dim=1) keeps one value per quiz.
                logp[:, model.id] = (
                    -F.cross_entropy(
                        output.transpose(1, 2), batch, reduction="none"
                    )
                    * ar_mask
                ).sum(dim=1)

            model.train(t)

    return seq_logproba.to("cpu")
###############################################################
- def models_successes(self, models_for_validation, c_quizzes):
+ def models_successes(self, models_for_validation, c_quizzes, device=None):
+ if device is None:
+ device = self.device
+
+ c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
+
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
- device=self.device,
+ device=device,
)
correctly_solved = torch.empty(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
- device=self.device,
+ device=device,
dtype=torch.int64,
)
seq_logproba[...] = 0.0
- c_quizzes = c_quizzes.to(self.device)
+ c_quizzes = c_quizzes.to(device)
reversed_c_quizzes = self.problem.reconfigure(
c_quizzes, ("f_A", "A", "f_B", "B")