X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=quiz_machine.py;h=eab41dc2a52275ccdae06326ef658503e0b1dd4e;hb=a10a4da02cdf8ea87a63e8d74381551e60cf7bd8;hp=88fd9f1dfea42ebfdbd99f290243b66d45c8f639;hpb=7b716a85786247b292ee71a635c98a18c66b421d;p=culture.git diff --git a/quiz_machine.py b/quiz_machine.py index 88fd9f1..eab41dc 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -373,7 +373,7 @@ class QuizMachine: return result, correct - compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") + # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( model.test_w_quizzes[:nmax], log_prefix="test" @@ -412,6 +412,12 @@ class QuizMachine: else: self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + ###################################################################### def logproba_of_solutions(self, models, c_quizzes):