Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 06:55:19 +0000 (08:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 27 Jul 2024 06:55:19 +0000 (08:55 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index be5e1bd..1b68eca 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -550,7 +550,14 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         # We discard the trivial ones, according to a criterion
         # specific to the world quizzes (e.g. B=f(B))
 
-        c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False]
+        rejected = []
+
+        to_keep == quiz_machine.problem.trivial(c_quizzes) == False
+
+        if not to_keep.all():
+            rejected.append(c_quizzes[to_keep == False])
+
+        c_quizzes = c_quizzes[to_keep]
 
         # We go through nb_rounds rounds and keep only quizzes on
         # which
@@ -564,7 +571,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         number_correct_responses = 0
         nb_remaining = [c_quizzes.size(0)]
-        rejected = []
 
         for r in range(args.nb_rounds):
             if c_quizzes.size(0) == 0:
index 13c157e..083b50e 100755 (executable)
@@ -102,6 +102,7 @@ class QuizMachine:
     ######################################################################
 
     def autoregression(
+        self,
         model,
         input,
         ar_mask,
@@ -139,7 +140,7 @@ class QuizMachine:
                     ar_mask=ar_mask,
                     seq_logproba=seq_logproba,
                     logit_transformer=logit_transformer,
-                    deterministic_synthesis=deterministic_synthesis,
+                    deterministic_synthesis=False,
                 )
 
             model.train(t)
@@ -193,6 +194,8 @@ class QuizMachine:
         ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
+        seq_logproba = torch.empty(quizzes.size(0), device=self.device)
+
         self.autoregression(
             model=model,
             input=result,