Update.
[picoclvr.git] / problems.py
index ef48162..7aa59be 100755 (executable)
@@ -26,7 +26,7 @@ class Problem:
 
 
 class ProblemDegradation(Problem):
-    def __init__(self, nb_state_tokens=5, nb_time_steps=5, value_max=25, hard=False):
+    def __init__(self, nb_state_tokens=7, nb_time_steps=10, value_max=100, hard=False):
         self.nb_state_tokens = nb_state_tokens
         self.nb_time_steps = nb_time_steps
         self.value_max = value_max
@@ -52,6 +52,7 @@ class ProblemDegradation(Problem):
     def compute_nb_correct(self, input, ar_mask, result):
         nb_total = result.size(0)
         nb_correct = 0
+        e=result.new_zeros(self.nb_state_tokens)
 
         for seq in result:
             states = list(seq.split(self.nb_state_tokens))
@@ -60,14 +61,14 @@ class ProblemDegradation(Problem):
 
             d = states[0]
             j=d.sort(descending=True).indices[0]
-            e=d.new_zeros(d.size())
+            e.zero_()
             e[j]=self.value_max
             if (d-e).abs().sum() == 0:
                 nb_errors = 0
                 for k in range(len(states)-1):
                     d=states[k]-states[k+1]
                     j=d.sort(descending=True).indices[0]
-                    e=d.new_zeros(d.size())
+                    e.zero_()
                     e[j]=d[j]
                     e[(j+1)%e.size(0)]=-d[j]//2
                     e[(j-1)%e.size(0)]=-d[j]//2
@@ -262,5 +263,6 @@ class ProblemAddition(Problem):
 if __name__ == "__main__":
     p = ProblemDegradation(hard=False)
     s, m = p.generate_sequences(10000)
-    print(p.seq2str(s[0]))
+    for x in s[:100]:
+        print(p.seq2str(x))
     print(p.compute_nb_correct(None, None, s))