def generate_frame_sequences_hard(self, nb):
frame_sequences = []
+ nb_frames = (self.nb_iterations - 1) * self.speed + 1
result = torch.full(
- (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+ (nb * 4, nb_frames, self.height, self.width),
self.token_empty,
)
result[n, 0, i + vi, j + vj] = self.token_tail
break
- if torch.rand(1) < 0.75:
- break
+ # if torch.rand(1) < 0.75:
+ break
weight = torch.full((1, 1, 3, 3), 1.0)
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
- for l in range(self.nb_iterations * self.speed - 1):
+ nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+ valid = nb_heads > 0
+
+ for l in range(nb_frames - 1):
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
+ (1 - mask_1_or_2_heads) * self.token_conductor
)
)
+ pred_nb_heads = nb_heads
+ nb_heads = (
+ (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+ )
+ valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+ result = result[valid]
result = result[
:, torch.arange(self.nb_iterations, device=result.device) * self.speed