Merge branch 'dev'
[culture.git] / wireworld.py
index 65b12ad..8257cad 100755 (executable)
@@ -62,9 +62,10 @@ class Wireworld(problem.Problem):
 
     def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
+        nb_frames = (self.nb_iterations - 1) * self.speed + 1
 
         result = torch.full(
-            (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+            (nb * 4, nb_frames, self.height, self.width),
             self.token_empty,
         )
 
@@ -116,8 +117,8 @@ class Wireworld(problem.Problem):
                         result[n, 0, i + vi, j + vj] = self.token_tail
                         break
 
-                if torch.rand(1) < 0.75:
-                    break
+                if torch.rand(1) < 0.75:
+                break
 
         weight = torch.full((1, 1, 3, 3), 1.0)
 
@@ -130,7 +131,10 @@ class Wireworld(problem.Problem):
         # tail->conductor
         # conductor->head if 1 or 2 head in the neighborhood, or remains conductor
 
-        for l in range(self.nb_iterations * self.speed - 1):
+        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+        valid = nb_heads > 0
+
+        for l in range(nb_frames - 1):
             nb_head_neighbors = (
                 F.conv2d(
                     input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@ -153,6 +157,13 @@ class Wireworld(problem.Problem):
                     + (1 - mask_1_or_2_heads) * self.token_conductor
                 )
             )
+            pred_nb_heads = nb_heads
+            nb_heads = (
+                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+            )
+            valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+        result = result[valid]
 
         result = result[
             :, torch.arange(self.nb_iterations, device=result.device) * self.speed