Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 10:20:13 +0000 (12:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 30 Jul 2024 10:20:13 +0000 (12:20 +0200)
grids.py
main.py
quiz_machine.py

index ebc2b0e..8d274ad 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -342,17 +342,15 @@ class Grids(problem.Problem):
     ):
         quizzes = quizzes.to("cpu")
 
-        to_reconfigure = [result]
+        to_reconfigure = [quizzes]
         if predicted_parts is not None:
             to_reconfigure.append(predicted_parts)
         if correct_parts is not None:
             to_reconfigure.append(correct_parts)
 
-        to_reconfigure = self.problem.reconfigure(
-            to_reconfigure, ("A", "f_A", "B", "f_B")
-        )
+        to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
 
-        result = to_reconfigure.pop(0)
+        quizzes = to_reconfigure.pop(0)
         if predicted_parts is not None:
             predicted_parts = to_reconfigure.pop(0)
         if correct_parts is not None:
diff --git a/main.py b/main.py
index 6f543a0..455aa1c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -493,7 +493,9 @@ def save_additional_results(models, science_w_quizzes):
                 mask=mask,
             )
 
-            predicted_parts = torch.tensor(mask, device=correct.device)[None, :]
+            predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
+                correct.size(0), -1
+            )
             correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
 
             nb_correct = (correct == 1).long().sum()
index 1ff23ed..90879ce 100755 (executable)
@@ -172,7 +172,7 @@ class QuizMachine:
                 from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
 
             self.randomize_configuations_inplace(
-                quizzes, structs=[s for s in self.understood_structures]
+                quizzes, structs=[s for s, m in self.understood_structures]
             )
 
             i = torch.randperm(quizzes.size(0), device=quizzes.device)
@@ -182,7 +182,7 @@ class QuizMachine:
     ######################################################################
 
     def make_ar_mask(self, quizzes, struct, mask):
-        assert struct in [s for s in self.understood_structures]
+        assert struct in [s for s, m in self.understood_structures]
         return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
@@ -296,7 +296,7 @@ class QuizMachine:
         )
 
         self.randomize_configuations_inplace(
-            model.train_w_quizzes, structs=[s for s in self.understood_structures]
+            model.train_w_quizzes, structs=[s for s, m in self.understood_structures]
         )
 
     ######################################################################