def standard_validity(logproba):
    """Per-quiz validity test over a batch of model log-probabilities.

    Args:
        logproba: float tensor of shape (nb_quizzes, nb_models); entry
            [i, m] is model m's log-probability for quiz i.

    Returns:
        Bool tensor of shape (nb_quizzes,). A quiz is deemed valid when,
        after sorting each row ascending, the *worst* model scores below
        log(0.5) (at least one model finds it hard) and the second-worst
        scores above log(0.95) (all the others find it easy).
    """
    # Sort each row ascending so column 0 is the worst model's score
    # and column 1 the second worst.
    sorted_lp = logproba.sort(dim=-1).values
    hard_for_one = sorted_lp[:, 0] < math.log(0.5)
    easy_for_rest = sorted_lp[:, 1] > math.log(0.95)
    return hard_for_one & easy_for_rest
# NOTE(review): this span is unresolved diff residue ('-'/'+' prefixed
# hunks fused together), not runnable Python. Left byte-identical;
# comments only. It needs the surrounding hunks resolved before it can
# be edited safely.
def valid_c_quizzes(recorded, criteria):
    # NOTE(review): `c_quizzes` is read before any visible binding — the
    # lines that build it (presumably from `recorded`) are outside this
    # excerpt; `quiz_machine`, `models`, `args` and `logp_file` are also
    # bound elsewhere. Confirm against the full file.
    c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
    if c_quizzes.size(0) > 0:
# The '-' hunk below is the old per-batch cross-entropy scoring being
# replaced by the '+' call to quiz_machine.logproba_solution.
- logproba = c_quizzes.new(c_quizzes.size(0), len(models))
- for q, l in zip(
- c_quizzes.split(args.batch_size), logproba.split(args.batch_size)
- ):
- for model in models:
- l[model.id] = F.cross_entropy(model(q))
-
+ logproba = quiz_machine.logproba_solution(models, c_quizzes)
# Dump one whitespace-separated line of per-model scores per quiz.
for l in logproba:
    s = " ".join([str(x.item()) for x in l])
    logp_file.write(s + "\n")
-
quizzes_and_logproba_records.append((c_quizzes, logproba))
nb_validated = valid_c_quizzes(
##################################################
# Replace a fraction of the w_quizzes with fresh ones
+ log_string(
+ f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
+ )
quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
##################################################
else:
self.test_c_quizzes.append(new_c_quizzes)
+ def logproba_solution(self, models, c_quizzes):
+ logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models))
+
+ for model in models:
+ for input, l in zip(
+ c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+ ):
+ ar_mask = self.make_ar_mask(input)
+ output = model(mygpt.BracketedSequence(input)).x
+ ce = (
+ F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ * ar_mask
+ )
+ l[:, model.id] = ce.sum(dim=-1)
+
+ return logproba
+
+ ###############################################################
+
def compute_correctness(
self,
c_quizzes,