X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=7aa59bea856c205d232ec1bf49c80f18ef20bed1;hb=16e7952b7cc32ca21498fa3a12fb79f679ea8c21;hp=ef48162e53d5a7aef4ddefb0d9f37f4cccbc635f;hpb=0d4b3fe3ccd16cd72fc96fe12c85996f35233c5e;p=picoclvr.git diff --git a/problems.py b/problems.py index ef48162..7aa59be 100755 --- a/problems.py +++ b/problems.py @@ -26,7 +26,7 @@ class Problem: class ProblemDegradation(Problem): - def __init__(self, nb_state_tokens=5, nb_time_steps=5, value_max=25, hard=False): + def __init__(self, nb_state_tokens=7, nb_time_steps=10, value_max=100, hard=False): self.nb_state_tokens = nb_state_tokens self.nb_time_steps = nb_time_steps self.value_max = value_max @@ -52,6 +52,7 @@ class ProblemDegradation(Problem): def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 + e=result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,14 +61,14 @@ class ProblemDegradation(Problem): d = states[0] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=self.value_max if (d-e).abs().sum() == 0: nb_errors = 0 for k in range(len(states)-1): d=states[k]-states[k+1] j=d.sort(descending=True).indices[0] - e=d.new_zeros(d.size()) + e.zero_() e[j]=d[j] e[(j+1)%e.size(0)]=-d[j]//2 e[(j-1)%e.size(0)]=-d[j]//2 @@ -262,5 +263,6 @@ class ProblemAddition(Problem): if __name__ == "__main__": p = ProblemDegradation(hard=False) s, m = p.generate_sequences(10000) - print(p.seq2str(s[0])) + for x in s[:100]: + print(p.seq2str(x)) print(p.compute_nb_correct(None, None, s))