# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
8 import torch, torchvision
9 import torch.nn.functional as F
def generate_sequences(
    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
):
    """Generate random snake walks over random colored grids.

    Each of the ``nb`` samples gets its own ``height`` x ``width`` grid whose
    cells hold colors drawn uniformly from ``range(nb_colors)``. The snake
    starts on a uniformly random cell with a uniformly random direction code in
    ``range(4)`` and takes ``length`` steps. At each step the candidate
    directions are the current code rotated by -1, 0 or +1 (mod 4), so the
    snake never reverses. Candidates that would leave the grid are excluded,
    and among the others the one with the highest random score wins. The
    straight-ahead score is scaled by 2 (turns by 1), biasing the walk toward
    moving forward.

    Direction code ``d`` moves the head by ``(d_row, d_col)``::

        0 -> (-1, 0)    1 -> (0, -1)    2 -> (1, 0)    3 -> (0, 1)

    Args:
        nb: number of independent sequences.
        height, width: grid dimensions.
        nb_colors: number of distinct cell colors.
        length: number of steps; each step emits two tokens.
        prompt_length: only steps ``l < prompt_length`` are counted in the
            prior-visit statistics.
        device: torch device on which every tensor is allocated.

    Returns:
        ``(sequences, sequences_prior_visits, worlds, world_prior_visits)``:

        * ``sequences``: int64 ``(nb, 2 * length)``. Token ``2 * l`` is the
          color of the cell occupied at step ``l`` plus 4 (so colors never
          collide with direction codes 0-3). Token ``2 * l + 1`` is the
          direction taken from that cell.
        * ``sequences_prior_visits``: int64 ``(nb, 2 * length)``. At even
          positions, the number of earlier steps ``k < min(l, prompt_length)``
          that visited the same cell; zero at odd positions.
        * ``worlds``: integer ``(nb, height, width)`` color grids.
        * ``world_prior_visits``: floating-point ``(nb, height, width)`` visit
          counts over the first ``prompt_length`` steps.

    Note:
        With ``height < 2`` or ``width < 2`` the snake can reach a cell whose
        only in-bounds neighbour is behind it. It then steps off the grid and
        later positions are out of range.
    """
    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
    world_prior_visits = torch.zeros(nb, height, width, device=device)

    # nb x 2: (row, col) of the snake's head
    snake_position = torch.cat(
        (
            torch.randint(height, (nb, 1), device=device),
            torch.randint(width, (nb, 1), device=device),
        ),
        1,
    )
    snake_direction = torch.randint(4, (nb,), device=device)
    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
    sequences_prior_visits = torch.zeros(
        nb, 2 * length, device=device, dtype=torch.int64
    )
    i = torch.arange(nb, device=device)  # [:,None]

    # Score multipliers for (rotate -1, keep direction, rotate +1); the
    # multiplicative factors bias toward moving forward. Loop-invariant.
    move_weights = torch.tensor([[1.0, 2.0, 1.0]], device=device)

    for l in range(length):
        # nb x 3: candidate direction codes
        snake_next_direction = torch.cat(
            (
                (snake_direction[:, None] - 1) % 4,
                snake_direction[:, None],
                (snake_direction[:, None] + 1) % 4,
            ),
            1,
        )

        # nb x 3: row / column displacement of each candidate
        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
        vw = snake_next_direction % 2 * (snake_next_direction - 2)

        # nb x 3 x 2
        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
        snake_next_position = snake_position[:, None, :] + snake_next_speed

        # nb x 3: True where the candidate stays inside the grid
        valid = torch.logical_and(
            torch.logical_and(
                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
            ),
            torch.logical_and(
                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
            ),
        )

        # Off-grid candidates get -1 so they lose to every in-grid one. Zeroing
        # them by multiplication would tie them with an in-grid candidate whose
        # uniform draw is exactly 0.0, and argmax could then choose a move
        # that leaves the grid.
        score = (torch.rand(valid.shape, device=device) * move_weights).masked_fill(
            ~valid, -1.0
        )

        # nb: index (0, 1 or 2) of the chosen candidate
        # NOTE(review): argmax selection restored from a damaged copy — confirm.
        j = score.argmax(1)
        snake_direction = snake_next_direction[i, j]

        cell = (i, snake_position[:, 0], snake_position[:, 1])
        # Colors are offset by 4 so they never collide with direction codes 0-3.
        sequences[:, 2 * l] = worlds[cell] + 4
        sequences_prior_visits[:, 2 * l] = world_prior_visits[cell]
        # NOTE(review): prompt_length guard restored from a damaged copy — confirm.
        if l < prompt_length:
            world_prior_visits[cell] += 1
        sequences[:, 2 * l + 1] = snake_direction

        # nb x 2
        snake_position = snake_next_position[i, j]

    return sequences, sequences_prior_visits, worlds, world_prior_visits
# Example call (the function is generate_sequences; prompt_length is required):
# generate_sequences(nb=1, height=4, width=6, nb_colors=3, length=20, prompt_length=20)
def solver(input, ar_mask):
    """Fill masked cell tokens of snake sequences from cells already seen.

    ``input`` holds one sequence per row: even positions are cell tokens and
    odd positions are direction codes. Each row's walk is replayed from a
    relative origin ``(0, 0)``. Direction ``d`` moves the head by
    ``((d + 1) % 2 * (d - 1), d % 2 * (d - 2))``.

    * Where ``ar_mask[n, 2 * l] == 1``, the cell token is overwritten with the
      token recorded for that relative cell. If the cell has not been seen yet
      in this row, it is overwritten with ``-1``.
    * Otherwise the token is recorded for its cell. If the cell was already
      recorded, the two tokens must agree.

    Direction tokens (odd positions) are only read, never written. The function
    returns ``None``.

    Args:
        input: 2-D tensor of token ids ``(N, T)``; modified in place.
        ar_mask: tensor indexable as ``ar_mask[n, t]``; only even positions are
            read.

    Raises:
        AssertionError: if an unmasked cell token contradicts the token
            recorded earlier at the same relative cell.
    """
    for n in range(input.size(0)):
        i, j = 0, 0
        memory = {}  # (i, j) -> cell token recorded at that relative cell
        for l in range(input.size(1) // 2):
            t = 2 * l
            if ar_mask[n, t] == 1:
                # NOTE(review): -1 sentinel restored from a damaged copy — confirm.
                input[n, t] = memory.get((i, j), -1)
            else:
                token = input[n, t].item()
                if memory.setdefault((i, j), token) != token:
                    # Explicit raise rather than `assert`, so the check also
                    # runs under `python -O`.
                    raise AssertionError(f"n={n} l={l}")
            d = input[n, t + 1].item()
            i += (d + 1) % 2 * (d - 1)
            j += d % 2 * (d - 2)
def seq2str(seq):
    """Render a token sequence as a compact string.

    Tokens 0-3 render as ``N``, ``E``, ``S``, ``W`` and tokens 4-12 as the
    digits ``1``-``9``. Each element of ``seq`` is used directly as a string
    index, so it must be an int or an integer 0-d tensor. Negative tokens wrap
    around (e.g. ``-1`` renders as ``9``); tokens above 12 raise ``IndexError``.
    """
    return "".join(["NESW123456789"[i] for i in seq])
######################################################################

if __name__ == "__main__":
    # NOTE(review): the original argument values of this call are not visible
    # in this copy of the file; the ones below are placeholders, so restore
    # the intended values. Colors are emitted as tokens c + 4 and seq2str
    # renders tokens up to 12, so keep nb_colors <= 9.
    train_input, train_prior_visits, _, _ = generate_sequences(
        nb=20,
        height=9,
        width=12,
        nb_colors=5,
        length=50,
        prompt_length=25,
    )
    print([seq2str(s) for s in train_input])

######################################################################