From cb7001fcd7a75eaeaca9ae66fce37e372acf8cc1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Oct 2023 18:04:38 +0200 Subject: [PATCH] Update. --- problems.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/problems.py b/problems.py index ef48162..819715e 100755 --- a/problems.py +++ b/problems.py @@ -52,6 +52,7 @@ class ProblemDegradation(Problem): def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 + e=result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,14 +61,14 @@ class ProblemDegradation(Problem): d = states[0] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=self.value_max if (d-e).abs().sum() == 0: nb_errors = 0 for k in range(len(states)-1): d=states[k]-states[k+1] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=d[j] e[(j+1)%e.size(0)]=-d[j]//2 e[(j-1)%e.size(0)]=-d[j]//2 -- 2.39.5