Update.
[picoclvr.git] / snake.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import torch, torchvision
9 import torch.nn.functional as F
10
11
12 def generate_sequences(
13     nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
14 ):
15     worlds = torch.randint(nb_colors, (nb, height, width), device=device)
16     nb_prior_visits = torch.zeros(nb, height, width, device=device)
17
18     # nb x 2
19     snake_position = torch.cat(
20         (
21             torch.randint(height, (nb, 1), device=device),
22             torch.randint(width, (nb, 1), device=device),
23         ),
24         1,
25     )
26     snake_direction = torch.randint(4, (nb,), device=device)
27     sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
28     sequences_prior_visits = torch.zeros(
29         nb, 2 * length, device=device, dtype=torch.int64
30     )
31     i = torch.arange(nb, device=device)  # [:,None]
32
33     for l in range(length):
34         # nb x 3
35         snake_next_direction = torch.cat(
36             (
37                 (snake_direction[:, None] - 1) % 4,
38                 snake_direction[:, None],
39                 (snake_direction[:, None] + 1) % 4,
40             ),
41             1,
42         )
43
44         # nb x 3
45         vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
46         vw = snake_next_direction % 2 * (snake_next_direction - 2)
47
48         # nb x 3 x 2
49         snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
50         snake_next_position = snake_position[:, None, :] + snake_next_speed
51
52         # nb x 3
53         val = torch.logical_and(
54             torch.logical_and(
55                 snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
56             ),
57             torch.logical_and(
58                 snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
59             ),
60         ).float()
61         val = (
62             # The multiplicative factors bias toward moving forward
63             torch.rand_like(val)
64             * val
65             * torch.tensor([[1.0, 2.0, 1.0]], device=device)
66         )
67
68         # nb
69         j = val.argmax(1)
70         snake_direction = snake_next_direction[i, j]
71
72         sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
73         sequences_prior_visits[:, 2 * l] = nb_prior_visits[
74             i, snake_position[:, 0], snake_position[:, 1]
75         ]
76         if l < prompt_length:
77             nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
78         sequences[:, 2 * l + 1] = snake_direction
79
80         # nb x 2
81         snake_position = snake_next_position[i, j]
82
83     return sequences, sequences_prior_visits
84
85
86 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
87 # exit(0)
88
89
90 def solver(input, ar_mask):
91     for n in range(input.size(0)):
92         i, j, memory = 0, 0, {}
93         # print(input[n])
94         # print(ar_mask[n])
95         for l in range(input.size(1) // 2):
96             if ar_mask[n, 2 * l] == 1:
97                 if memory.get((i, j)) is None:
98                     input[n, 2 * l] = -1
99                 else:
100                     input[n, 2 * l] = memory[(i, j)]
101             else:
102                 # print(f'@3 {memory=}')
103                 if memory.get((i, j)) is None:
104                     memory[(i, j)] = input[n, 2 * l]
105                 else:
106                     assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
107             # print(f'@1 {i=} {j=}')
108             d = input[n, 2 * l + 1].item()
109             i += (d + 1) % 2 * (d - 1)
110             j += d % 2 * (d - 2)
111             # print(f'@2 {i=} {j=}')
112
113
114 ######################################################################
115
116 if __name__ == "__main__":
117     for n in range(16):
118         descr = generate(nb=1, height=12, width=16)
119
120         print(nb_properties(descr, height=12, width=16))
121
122         with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
123             for d in descr:
124                 f.write(f"{d}\n\n")
125
126         img = descr2img(descr, height=12, width=16)
127         if img.size(0) == 1:
128             img = F.pad(img, (1, 1, 1, 1), value=64)
129
130         torchvision.utils.save_image(
131             img / 255.0,
132             f"picoclvr_example_{n:02d}.png",
133             padding=1,
134             nrow=4,
135             pad_value=0.8,
136         )
137
138     import time
139
140     start_time = time.perf_counter()
141     descr = generate(nb=1000, height=12, width=16)
142     end_time = time.perf_counter()
143     print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
144
145 ######################################################################