X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=wireworld.py;h=8257cadfc6599acb2bc0f4cb72b6245b762897c2;hb=3b41e2797fc340fd11cb35015b57c3cae1e8447b;hp=65b12adda45e27ef20d3b4d646f092f663b37356;hpb=f9aee903b896ae73dafe0cce1dcc40d8a39accb0;p=culture.git diff --git a/wireworld.py b/wireworld.py index 65b12ad..8257cad 100755 --- a/wireworld.py +++ b/wireworld.py @@ -62,9 +62,10 @@ class Wireworld(problem.Problem): def generate_frame_sequences_hard(self, nb): frame_sequences = [] + nb_frames = (self.nb_iterations - 1) * self.speed + 1 result = torch.full( - (nb * 4, self.nb_iterations * self.speed, self.height, self.width), + (nb * 4, nb_frames, self.height, self.width), self.token_empty, ) @@ -116,8 +117,8 @@ class Wireworld(problem.Problem): result[n, 0, i + vi, j + vj] = self.token_tail break - if torch.rand(1) < 0.75: - break + # if torch.rand(1) < 0.75: + break weight = torch.full((1, 1, 3, 3), 1.0) @@ -130,7 +131,10 @@ class Wireworld(problem.Problem): # tail->conductor # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - for l in range(self.nb_iterations * self.speed - 1): + nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1) + valid = nb_heads > 0 + + for l in range(nb_frames - 1): nb_head_neighbors = ( F.conv2d( input=(result[:, l] == self.token_head).float()[:, None, :, :], @@ -153,6 +157,13 @@ class Wireworld(problem.Problem): + (1 - mask_1_or_2_heads) * self.token_conductor ) ) + pred_nb_heads = nb_heads + nb_heads = ( + (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1) + ) + valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads)) + + result = result[valid] result = result[ :, torch.arange(self.nb_iterations, device=result.device) * self.speed