projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
0d4b3fe
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 21 Oct 2023 16:04:38 +0000
(18:04 +0200)
committer
François Fleuret
<francois@fleuret.org>
Sat, 21 Oct 2023 16:04:38 +0000
(18:04 +0200)
problems.py
patch
|
blob
|
history
diff --git
a/problems.py
b/problems.py
index
ef48162
..
819715e
100755
(executable)
--- a/
problems.py
+++ b/
problems.py
@@
-52,6
+52,7
@@
class ProblemDegradation(Problem):
def compute_nb_correct(self, input, ar_mask, result):
nb_total = result.size(0)
nb_correct = 0
def compute_nb_correct(self, input, ar_mask, result):
nb_total = result.size(0)
nb_correct = 0
+ e=result.new_zeros(self.nb_state_tokens)
for seq in result:
states = list(seq.split(self.nb_state_tokens))
for seq in result:
states = list(seq.split(self.nb_state_tokens))
@@
-60,14
+61,14
@@
class ProblemDegradation(Problem):
d = states[0]
j=d.sort(descending=True).indices[0]
d = states[0]
j=d.sort(descending=True).indices[0]
- e
=d.new_zeros(d.size()
)
+ e
.zero_(
)
e[j]=self.value_max
if (d-e).abs().sum() == 0:
nb_errors = 0
for k in range(len(states)-1):
d=states[k]-states[k+1]
j=d.sort(descending=True).indices[0]
e[j]=self.value_max
if (d-e).abs().sum() == 0:
nb_errors = 0
for k in range(len(states)-1):
d=states[k]-states[k+1]
j=d.sort(descending=True).indices[0]
- e
=d.new_zeros(d.size()
)
+ e
.zero_(
)
e[j]=d[j]
e[(j+1)%e.size(0)]=-d[j]//2
e[(j-1)%e.size(0)]=-d[j]//2
e[j]=d[j]
e[(j+1)%e.size(0)]=-d[j]//2
e[(j-1)%e.size(0)]=-d[j]//2