From 1e7c259b1dd038a0f45dba96e872cd1121f38f96 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 12:37:41 +0200 Subject: [PATCH] Update. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index c810eef..4a332b8 100755 --- a/main.py +++ b/main.py @@ -498,7 +498,7 @@ for k in range(nb_epochs_finished, nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) -- 2.39.5