Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 06:21:17 +0000 (08:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 19 Jul 2024 06:21:17 +0000 (08:21 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index ab87b56..d9257db 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -376,7 +376,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        n_p2a = input[:, 0] == quiz_machine.token_p2a
+        n_p2a = input[:, 0] == quiz_machine.problem.token_forward
         to_store = from_w & n_p2a.to("cpu")
         if to_store.any():
             hard_w_quizzes.append(
@@ -496,11 +496,11 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
     v_train = validated_quizzes[:nb_for_train]
     quiz_machine.store_c_quizzes(v_train, for_train=True)
-    quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_train), for_train=True)
+    quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_train), for_train=True)
 
     v_test = validated_quizzes[nb_for_train:nb_to_create]
     quiz_machine.store_c_quizzes(v_test, for_train=False)
-    quiz_machine.store_c_quizzes(quiz_machine.reverse_time(v_test), for_train=False)
+    quiz_machine.store_c_quizzes(quiz_machine.p_a_flip(v_test), for_train=False)
 
     ######################################################################
     # save images
index 51c3f08..cc81086 100755 (executable)
@@ -281,8 +281,8 @@ class QuizMachine:
         self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
+            quizzes[:, : self.prompt_len],
+            quizzes[:, self.prompt_len :],
             predicted_prompts,
             predicted_answers,
         )
@@ -358,9 +358,7 @@ class QuizMachine:
             if self.back_accuracy and n_a2p.any():
                 # accuracy of B->A*->B*=B instead of B->A*=A
                 back_input = self.p_a_flip(result[n_a2p])
-                back_input[:, 2 + self.prompt_len :] = input[
-                    n_a2p, 1 : 1 + self.answer_len
-                ]
+                back_input[:, 1 + self.prompt_len :] = input[n_a2p, 1 : self.answer_len]
                 _, correct[n_a2p] = compute_accuracy(back_input)
 
             if log_prefix is not None: