From a86dff174205c38d8e90d0d89ea399a6afb36359 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 11 Jul 2024 17:52:40 +0200 Subject: [PATCH] Update. --- main.py | 2 ++ quiz_machine.py | 28 +++++++++++++++++----------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 4cf4d59..a7338c7 100755 --- a/main.py +++ b/main.py @@ -18,6 +18,8 @@ import sky, grids, quiz_machine import threading +import torch.multiprocessing as mp + # world quizzes vs. culture quizzes ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index ae14614..8ab5696 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -424,17 +424,23 @@ class QuizMachine: ) for model in models: - for input, l in zip( - c_quizzes.split(self.batch_size), logproba.split(self.batch_size) - ): - input = input.to(self.device) - ar_mask = self.make_ar_mask(input) - output = model(mygpt.BracketedSequence(input)).x - ce = ( - F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * ar_mask - ) - l[:, model.id] = -ce.sum(dim=-1) + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + ce = ( + F.cross_entropy(output.transpose(1, 2), input, reduction="none") + * ar_mask + ) + l[:, model.id] = -ce.sum(dim=-1) + + model.train(t) return logproba.to("cpu") -- 2.39.5