Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 10:08:15 +0000 (12:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 1 Aug 2024 10:08:15 +0000 (12:08 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 72b2b26..526da6f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -441,12 +441,12 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     run_tests(model, quiz_machine)
 
-    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
-    threshold = threshold[threshold.size(0) // 2]
+    threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
+    threshold = threshold[threshold.size(0) // 2]
 
-    model.hard_w_quizzes = torch.cat(
-        [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
-    )
+    model.hard_w_quizzes = torch.cat(
+    # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
+    )
 
     model.to(main_device)
 
index a042431..b7c3b09 100755 (executable)
@@ -228,7 +228,7 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask, noise_mask in self.understood_structures:
+        for struct, mask, _ in self.understood_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i] = self.predict(