Update.
author François Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 03:24:56 +0000 (05:24 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 17 Jul 2024 03:24:56 +0000 (05:24 +0200)
main.py

diff --git a/main.py b/main.py
index 76db5e2..74a3cfb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -367,8 +367,8 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-        n_forward = input[:, 0] == self.token_forward
-        to_store = from_w & n_forward
+        n_forward = input[:, 0] == quiz_machine.token_forward
+        to_store = from_w & n_forward.to("cpu")
         if to_store.any():
             hard_w_quizzes.append(
                 (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))