"_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
)
- def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4):
+ def __init__(
+ self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
+ ):
self.height = height
self.width = width
self.nb_objects = nb_objects
self.nb_walls = nb_walls
+ self.speed = speed
self.nb_iterations = nb_iterations
def direction_tokens(self):
return self.token_forward, self.token_backward
def generate_frame_sequences(self, nb):
+ result = []
+ N = 100
+ for _ in tqdm.tqdm(
+ range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
+ ):
+ result.append(self.generate_frame_sequences_hard(100))
+ return torch.cat(result, dim=0)[:nb]
+
+ def generate_frame_sequences_hard(self, nb):
frame_sequences = []
+ nb_frames = (self.nb_iterations - 1) * self.speed + 1
result = torch.full(
- (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty
+ (nb * 4, nb_frames, self.height, self.width),
+ self.token_empty,
)
for n in range(result.size(0)):
while True:
if i < 0 or i >= self.height or j < 0 or j >= self.width:
break
+ o = 0
+ if i > 0:
+ o += (result[n, 0, i - 1, j] == self.token_conductor).long()
+ if i < self.height - 1:
+ o += (result[n, 0, i + 1, j] == self.token_conductor).long()
+ if j > 0:
+ o += (result[n, 0, i, j - 1] == self.token_conductor).long()
+ if j < self.width - 1:
+ o += (result[n, 0, i, j + 1] == self.token_conductor).long()
+ if o > 1:
+ break
result[n, 0, i, j] = self.token_conductor
i += vi
j += vj
- if torch.rand(1) < 0.5:
+ if (
+ result[n, 0] == self.token_conductor
+ ).long().sum() > self.width and torch.rand(1) < 0.5:
break
+ while True:
+ for _ in range(self.height * self.width):
+ i = torch.randint(self.height, (1,))
+ j = torch.randint(self.width, (1,))
+ v = torch.randint(2, (2,))
+ vi = v[0] * (v[1] * 2 - 1)
+ vj = (1 - v[0]) * (v[1] * 2 - 1)
+ if (
+ i + vi >= 0
+ and i + vi < self.height
+ and j + vj >= 0
+ and j + vj < self.width
+ and result[n, 0, i, j] == self.token_conductor
+ and result[n, 0, i + vi, j + vj] == self.token_conductor
+ ):
+ result[n, 0, i, j] = self.token_head
+ result[n, 0, i + vi, j + vj] = self.token_tail
+ break
+
+ # if torch.rand(1) < 0.75:
+ break
+
weight = torch.full((1, 1, 3, 3), 1.0)
mask = (torch.rand(result[:, 0].size()) < 0.01).long()
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
- for l in range(self.nb_iterations - 1):
+ nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+ valid = nb_heads > 0
+
+ for l in range(nb_frames - 1):
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
+ (1 - mask_1_or_2_heads) * self.token_conductor
)
)
+ pred_nb_heads = nb_heads
+ nb_heads = (
+ (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+ )
+ valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads))
- i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
+ result = result[valid]
+
+ result = result[
+ :, torch.arange(self.nb_iterations, device=result.device) * self.speed
+ ]
+ i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
result = result[i]
+ # print(f"{result.size(0)=} {nb=}")
+
if result.size(0) < nb:
# print(result.size(0))
result = torch.cat(
[result, self.generate_frame_sequences(nb - result.size(0))], dim=0
)
- return result
+ return result[:nb]
def generate_token_sequences(self, nb):
frame_sequences = self.generate_frame_sequences(nb)
if __name__ == "__main__":
import time
- wireworld = Wireworld(height=10, width=15, nb_iterations=4)
+ wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
start_time = time.perf_counter()
frame_sequences = wireworld.generate_frame_sequences(nb=96)
# m = (torch.rand(seq.size()) < 0.05).long()
# seq = (1 - m) * seq + m * 23
+ wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+ token_sequences = wireworld.generate_token_sequences(32)
+ wireworld.save_quizzes(token_sequences, "/tmp", "seq")
# img = wireworld.seq2img(frame_sequences[:60])
# torchvision.utils.save_image(