From db3521b1fc580aad33970c5fad9dbf7d3721ac68 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 27 Jul 2024 08:55:19 +0200 Subject: [PATCH] Update. --- main.py | 10 ++++++++-- quiz_machine.py | 5 ++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index be5e1bd..1b68eca 100755 --- a/main.py +++ b/main.py @@ -550,7 +550,14 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # We discard the trivial ones, according to a criterion # specific to the world quizzes (e.g. B=f(B)) - c_quizzes = c_quizzes[quiz_machine.problem.trivial(c_quizzes) == False] + rejected = [] + + to_keep == quiz_machine.problem.trivial(c_quizzes) == False + + if not to_keep.all(): + rejected.append(c_quizzes[to_keep == False]) + + c_quizzes = c_quizzes[to_keep] # We go through nb_rounds rounds and keep only quizzes on # which @@ -564,7 +571,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 number_correct_responses = 0 nb_remaining = [c_quizzes.size(0)] - rejected = [] for r in range(args.nb_rounds): if c_quizzes.size(0) == 0: diff --git a/quiz_machine.py b/quiz_machine.py index 13c157e..083b50e 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -102,6 +102,7 @@ class QuizMachine: ###################################################################### def autoregression( + self, model, input, ar_mask, @@ -139,7 +140,7 @@ class QuizMachine: ar_mask=ar_mask, seq_logproba=seq_logproba, logit_transformer=logit_transformer, - deterministic_synthesis=deterministic_synthesis, + deterministic_synthesis=False, ) model.train(t) @@ -193,6 +194,8 @@ class QuizMachine: ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) + seq_logproba = torch.empty(quizzes.size(0), device=self.device) + self.autoregression( model=model, input=result, -- 2.39.5