else:
self.test_c_quizzes.append(new_c_quizzes)
def logproba_solution(self, models, c_quizzes):
    """Score every quiz in ``c_quizzes`` under every model in ``models``.

    For each model the quizzes are processed in chunks of
    ``self.batch_size``. The per-token cross-entropy between the model
    output and the quiz tokens is multiplied by the mask returned by
    ``self.make_ar_mask`` and summed over the last dimension.

    Args:
        models: iterable of models. Each one is called on a
            ``mygpt.BracketedSequence`` and must expose an ``id``
            attribute, which is used as the column index of the result
            (so ids are expected to lie in ``range(len(models))``).
        c_quizzes: tensor of token ids, one quiz per row (it is used as
            the class-index target of ``F.cross_entropy``).

    Returns:
        A float tensor of shape ``(c_quizzes.size(0), len(models))``,
        detached from the autograd graph, whose entry ``[n, model.id]``
        is the masked cross-entropy of quiz ``n`` under that model.

    NOTE(review): the stored value is the summed cross-entropy, i.e. the
    *negative* log-probability of the masked tokens, despite the method
    name. The sign is kept as-is; confirm what callers expect.
    """
    # ``new_zeros`` inherits the dtype of ``c_quizzes`` (integer token
    # ids), so convert explicitly: writing float scores into an integer
    # tensor would silently truncate them.
    logproba = c_quizzes.new_zeros(c_quizzes.size(0), len(models)).float()

    for model in models:
        # ``split`` returns views, so writing into ``scores`` fills the
        # matching rows of ``logproba`` in place.
        for quizzes, scores in zip(
            c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
        ):
            ar_mask = self.make_ar_mask(quizzes)
            output = model(mygpt.BracketedSequence(quizzes)).x
            ce = (
                F.cross_entropy(output.transpose(1, 2), quizzes, reduction="none")
                * ar_mask
            )
            # Detach so the result does not keep every batch's autograd
            # graph alive now that it is a floating-point tensor.
            scores[:, model.id] = ce.sum(dim=-1).detach()

    return logproba
+
+ ###############################################################
+
def compute_correctness(
self,
c_quizzes,