if args.problem == "sky":
problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3)
elif args.problem == "wireworld":
- problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4)
+ problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
else:
raise ValueError
return self.token_forward, self.token_backward
def generate_frame_sequences(self, nb):
+ result = []
+ N = 100
+ for _ in tqdm.tqdm(
+ range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
+ ):
+ result.append(self.generate_frame_sequences_hard(100))
+ return torch.cat(result, dim=0)[:nb]
+
+ def generate_frame_sequences_hard(self, nb):
frame_sequences = []
result = torch.full(
i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
result = result[i]
- print(f"{result.size(0)=} {nb=}")
+ # print(f"{result.size(0)=} {nb=}")
if result.size(0) < nb:
# print(result.size(0))
if __name__ == "__main__":
import time
- wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)
+ wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
start_time = time.perf_counter()
frame_sequences = wireworld.generate_frame_sequences(nb=96)