break
result = torch.zeros(
- self.nb_iterations, self.height, self.width, dtype=torch.int64
+ self.nb_iterations * self.speed,
+ self.height,
+ self.width,
+ dtype=torch.int64,
)
- for l in range(self.nb_iterations):
+ for l in range(self.nb_iterations * self.speed):
fine = collision_okay()
for n in range(self.nb_birds):
c = col[n]
if fine:
break
- frame_sequences.append(result)
+ frame_sequences.append(
+ result[
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ ]
+ )
return frame_sequences
"_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
)
- def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4):
+ def __init__(
+ self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
+ ):
self.height = height
self.width = width
self.nb_objects = nb_objects
self.nb_walls = nb_walls
+ self.speed = speed
self.nb_iterations = nb_iterations
def direction_tokens(self):
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
- for l in range(self.nb_iterations - 1):
+ for l in range(self.nb_iterations * self.speed - 1):
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
- result = result[i]
+ result = result[
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ ]
if result.size(0) < nb:
# print(result.size(0))
[result, self.generate_frame_sequences(nb - result.size(0))], dim=0
)
- return result
+ return result[:nb]
def generate_token_sequences(self, nb):
frame_sequences = self.generate_frame_sequences(nb)
if __name__ == "__main__":
import time
- wireworld = Wireworld(height=10, width=15, nb_iterations=4)
+ wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1)
start_time = time.perf_counter()
frame_sequences = wireworld.generate_frame_sequences(nb=96)
# print(wireworld.seq2str(seq[:4]))
- for t in range(frame_sequences.size(1)):
- img = wireworld.seq2img(frame_sequences[:, t])
- torchvision.utils.save_image(
- img.float() / 255.0,
- f"/tmp/frame_{t:03d}.png",
- nrow=8,
- padding=6,
- pad_value=0,
- )
+ # for t in range(frame_sequences.size(1)):
+ # img = wireworld.seq2img(frame_sequences[:, t])
+ # torchvision.utils.save_image(
+ # img.float() / 255.0,
+ # f"/tmp/frame_{t:03d}.png",
+ # nrow=8,
+ # padding=6,
+ # pad_value=0,
+ # )
# m = (torch.rand(seq.size()) < 0.05).long()
# seq = (1 - m) * seq + m * 23
+ token_sequences = wireworld.generate_token_sequences(32)
+ wireworld.save_quizzes(token_sequences, "/tmp", "seq")
# img = wireworld.seq2img(frame_sequences[:60])
# torchvision.utils.save_image(