From 11ea129f7f4b92ea3cd6aaf68bb91de911682297 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 12:08:15 +0200 Subject: [PATCH] Update. --- main.py | 10 +++++----- quiz_machine.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 72b2b26..526da6f 100755 --- a/main.py +++ b/main.py @@ -441,12 +441,12 @@ def one_epoch(model, quiz_machine, local_device=main_device): run_tests(model, quiz_machine) - threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values - threshold = threshold[threshold.size(0) // 2] + # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values + # threshold = threshold[threshold.size(0) // 2] - model.hard_w_quizzes = torch.cat( - [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 - ) + # model.hard_w_quizzes = torch.cat( + # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 + # ) model.to(main_device) diff --git a/quiz_machine.py b/quiz_machine.py index a042431..b7c3b09 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -228,7 +228,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask, noise_mask in self.understood_structures: + for struct, mask, _ in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict( -- 2.20.1