X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=819715e1b5b1bab6af7207f8656f6aaefb8408f0;hb=cb7001fcd7a75eaeaca9ae66fce37e372acf8cc1;hp=ef48162e53d5a7aef4ddefb0d9f37f4cccbc635f;hpb=0d4b3fe3ccd16cd72fc96fe12c85996f35233c5e;p=picoclvr.git diff --git a/problems.py b/problems.py index ef48162..819715e 100755 --- a/problems.py +++ b/problems.py @@ -52,6 +52,7 @@ class ProblemDegradation(Problem): def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 + e=result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,14 +61,14 @@ class ProblemDegradation(Problem): d = states[0] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=self.value_max if (d-e).abs().sum() == 0: nb_errors = 0 for k in range(len(states)-1): d=states[k]-states[k+1] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=d[j] e[(j+1)%e.size(0)]=-d[j]//2 e[(j-1)%e.size(0)]=-d[j]//2