nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
):
worlds = torch.randint(nb_colors, (nb, height, width), device=device)
- nb_prior_visits = torch.zeros(nb, height, width, device=device)
+ world_prior_visits = torch.zeros(nb, height, width, device=device)
# nb x 2
snake_position = torch.cat(
snake_direction = snake_next_direction[i, j]
sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
- sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+ sequences_prior_visits[:, 2 * l] = world_prior_visits[
i, snake_position[:, 0], snake_position[:, 1]
]
if l < prompt_length:
- nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+ world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
sequences[:, 2 * l + 1] = snake_direction
# nb x 2
snake_position = snake_next_position[i, j]
- return sequences, sequences_prior_visits
+ return sequences, sequences_prior_visits, worlds, world_prior_visits
# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
# print(f'@2 {i=} {j=}')
-######################################################################
-
-if __name__ == "__main__":
- for n in range(16):
- descr = generate(nb=1, height=12, width=16)
-
- print(nb_properties(descr, height=12, width=16))
-
- with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
- for d in descr:
- f.write(f"{d}\n\n")
+def seq2str(seq):
+ return "".join(["NESW123456789"[i] for i in seq])
- img = descr2img(descr, height=12, width=16)
- if img.size(0) == 1:
- img = F.pad(img, (1, 1, 1, 1), value=64)
- torchvision.utils.save_image(
- img / 255.0,
- f"picoclvr_example_{n:02d}.png",
- padding=1,
- nrow=4,
- pad_value=0.8,
- )
+######################################################################
- import time
+if __name__ == "__main__":
+ train_input, train_prior_visits, _, _ = generate_sequences(
+ nb=20,
+ height=9,
+ width=12,
+ nb_colors=5,
+ length=50,
+ prompt_length=100,
+ )
- start_time = time.perf_counter()
- descr = generate(nb=1000, height=12, width=16)
- end_time = time.perf_counter()
- print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+ print([seq2str(s) for s in train_input])
######################################################################