projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
wireworld.py
diff --git
a/wireworld.py
b/wireworld.py
index
65b12ad
..
8257cad
100755
(executable)
--- a/
wireworld.py
+++ b/
wireworld.py
@@
-62,9
+62,10
@@
class Wireworld(problem.Problem):
def generate_frame_sequences_hard(self, nb):
frame_sequences = []
def generate_frame_sequences_hard(self, nb):
frame_sequences = []
+ nb_frames = (self.nb_iterations - 1) * self.speed + 1
result = torch.full(
result = torch.full(
- (nb * 4,
self.nb_iterations * self.speed
, self.height, self.width),
+ (nb * 4,
nb_frames
, self.height, self.width),
self.token_empty,
)
self.token_empty,
)
@@
-116,8
+117,8
@@
class Wireworld(problem.Problem):
result[n, 0, i + vi, j + vj] = self.token_tail
break
result[n, 0, i + vi, j + vj] = self.token_tail
break
- if torch.rand(1) < 0.75:
-
break
+
#
if torch.rand(1) < 0.75:
+ break
weight = torch.full((1, 1, 3, 3), 1.0)
weight = torch.full((1, 1, 3, 3), 1.0)
@@
-130,7
+131,10
@@
class Wireworld(problem.Problem):
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
- for l in range(self.nb_iterations * self.speed - 1):
+ nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+ valid = nb_heads > 0
+
+ for l in range(nb_frames - 1):
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
@@
-153,6
+157,13
@@
class Wireworld(problem.Problem):
+ (1 - mask_1_or_2_heads) * self.token_conductor
)
)
+ (1 - mask_1_or_2_heads) * self.token_conductor
)
)
+ pred_nb_heads = nb_heads
+ nb_heads = (
+ (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+ )
+ valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
+
+ result = result[valid]
result = result[
:, torch.arange(self.nb_iterations, device=result.device) * self.speed
result = result[
:, torch.arange(self.nb_iterations, device=result.device) * self.speed