X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=snake.py;h=8a16f9f6bef6228def574f05730a19f1210dd875;hb=fdc61b7e50e029aac58b10f377acdce549532f84;hp=eb46a076e5e548111c9ff100e73c1f904cc4d9d9;hpb=cf7fcbb7a946c4d1f4d29a28e0eb04940d3b0f76;p=picoclvr.git diff --git a/snake.py b/snake.py index eb46a07..8a16f9f 100755 --- a/snake.py +++ b/snake.py @@ -13,7 +13,7 @@ def generate_sequences( nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") ): worlds = torch.randint(nb_colors, (nb, height, width), device=device) - nb_prior_visits = torch.zeros(nb, height, width, device=device) + world_prior_visits = torch.zeros(nb, height, width, device=device) # nb x 2 snake_position = torch.cat( @@ -70,17 +70,17 @@ def generate_sequences( snake_direction = snake_next_direction[i, j] sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 - sequences_prior_visits[:, 2 * l] = nb_prior_visits[ + sequences_prior_visits[:, 2 * l] = world_prior_visits[ i, snake_position[:, 0], snake_position[:, 1] ] if l < prompt_length: - nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 sequences[:, 2 * l + 1] = snake_direction # nb x 2 snake_position = snake_next_position[i, j] - return sequences, sequences_prior_visits + return sequences, sequences_prior_visits, worlds, world_prior_visits # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) @@ -111,35 +111,22 @@ def solver(input, ar_mask): # print(f'@2 {i=} {j=}') -###################################################################### - -if __name__ == "__main__": - for n in range(16): - descr = generate(nb=1, height=12, width=16) - - print(nb_properties(descr, height=12, width=16)) - - with open(f"picoclvr_example_{n:02d}.txt", "w") as f: - for d in descr: - f.write(f"{d}\n\n") +def seq2str(seq): + return "".join(["NESW123456789"[i] for i in seq]) - img = descr2img(descr, height=12, width=16) - if img.size(0) == 1: - img = F.pad(img, (1, 1, 1, 1), value=64) - torchvision.utils.save_image( - img / 255.0, - f"picoclvr_example_{n:02d}.png", - padding=1, - nrow=4, - pad_value=0.8, - ) +###################################################################### - import time +if __name__ == "__main__": + train_input, train_prior_visits, _, _ = generate_sequences( + nb=20, + height=9, + width=12, + nb_colors=5, + length=50, + prompt_length=100, + ) - start_time = time.perf_counter() - descr = generate(nb=1000, height=12, width=16) - end_time = time.perf_counter() - print(f"{len(descr) / (end_time - start_time):.02f} samples per second") + print([seq2str(s) for s in train_input]) ######################################################################