From ce7b4d55087703e6f67da2c26c9bed22f58113eb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 13 Aug 2024 00:07:15 +0200 Subject: [PATCH] Update. --- main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/main.py b/main.py index 0a79323..bd46948 100755 --- a/main.py +++ b/main.py @@ -540,8 +540,7 @@ def save_additional_results(model, models, science_w_quizzes): for model in models ] - seq_logprobas = torch.cat([x[None, :] for x in l]) - + seq_logprobas = torch.cat([x[:, None] for x in l], dim=1) probas = seq_logprobas.exp() comments = [] -- 2.39.5