Added figures
[culture.git] / quizz_machine.py
index 198d279..697f27e 100755 (executable)
@@ -139,6 +139,7 @@ class QuizzMachine:
         self.train_w_quizzes = self.problem.generate_token_sequences(
             nb_train_samples
         ).to(device)
+
         self.test_w_quizzes = self.problem.generate_token_sequences(nb_test_samples).to(
             device
         )
@@ -287,7 +288,7 @@ class QuizzMachine:
         return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
 
     def compute_correctness(
-        self, c_quizzes, models_for_validation, both_direction=True
+        self, c_quizzes, models_for_validation, both_directions=True
     ):
         reversed_c_quizzes = self.reverse_time(c_quizzes)
 
@@ -315,7 +316,7 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            if both_direction:
+            if both_directions:
                 reversed_result = reversed_c_quizzes.clone()
 
                 masked_inplace_autoregression(