From: François Fleuret Date: Thu, 1 Aug 2024 10:08:15 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=11ea129f7f4b92ea3cd6aaf68bb91de911682297;p=culture.git Update. --- diff --git a/main.py b/main.py index 72b2b26..526da6f 100755 --- a/main.py +++ b/main.py @@ -441,12 +441,12 @@ def one_epoch(model, quiz_machine, local_device=main_device): run_tests(model, quiz_machine) - threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values - threshold = threshold[threshold.size(0) // 2] + # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values + # threshold = threshold[threshold.size(0) // 2] - model.hard_w_quizzes = torch.cat( - [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 - ) + # model.hard_w_quizzes = torch.cat( + # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 + # ) model.to(main_device) diff --git a/quiz_machine.py b/quiz_machine.py index a042431..b7c3b09 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -228,7 +228,7 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask, noise_mask in self.understood_structures: + for struct, mask, _ in self.understood_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i] = self.predict(