)
if self.back_accuracy:
+ # If back_accuracy is True, we compute the accuracy on
+ # the backward quizzes not by counting how many time
+ # the real prompt A is equal to the reconstructed
+ # prompt A*, but how many time the answers B* computed
+ # from A* is equal to the correct answer. So we look
+ # for the accuracy of A->B*=B for the forward, but for
+ # the backward we look at B->A*->B*=B instead of B->A*=A
+
n_forward = input[:, 0] == self.token_forward
nb_total = input[n_forward].size(0)
nb_correct = (