Update.
[culture.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
18 colors = torch.tensor(
19     [
20         [255, 255, 255],
21         [0, 0, 0],
22         [255, 0, 0],
23         [0, 128, 0],
24         [0, 0, 255],
25         [255, 255, 0],
26         [192, 192, 192],
27     ]
28 )
29
30 token2char = "_X" + "".join([str(n) for n in range(len(colors) - 2)]) + ">"
31
32
33 def generate(
34     nb,
35     height,
36     width,
37     max_nb_obj=2,
38     nb_iterations=2,
39 ):
40     f_start = torch.zeros(nb, height, width, dtype=torch.int64)
41     f_end = torch.zeros(nb, height, width, dtype=torch.int64)
42     n = torch.arange(f_start.size(0))
43
44     for n in range(nb):
45         nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
46         for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values:
47             i, j = (
48                 torch.randint(height - 2, (1,))[0] + 1,
49                 torch.randint(width - 2, (1,))[0] + 1,
50             )
51             vm = torch.randint(4, (1,))[0]
52             vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
53
54             f_start[n, i, j] = c + 2
55             f_start[n, i - vi, j - vj] = c + 2
56             f_start[n, i + vj, j - vi] = c + 2
57             f_start[n, i - vj, j + vi] = c + 2
58
59             for l in range(nb_iterations):
60                 i += vi
61                 j += vj
62                 if i < 0 or i >= height or j < 0 or j >= width:
63                     i -= vi
64                     j -= vj
65                     vi, vj = -vi, -vj
66                     i += vi
67                     j += vj
68
69             f_end[n, i, j] = c + 2
70             f_end[n, i - vi, j - vj] = c + 2
71             f_end[n, i + vj, j - vi] = c + 2
72             f_end[n, i - vj, j + vi] = c + 2
73
74     return torch.cat(
75         [
76             f_end.flatten(1),
77             torch.full((f_end.size(0), 1), len(colors)),
78             f_start.flatten(1),
79         ],
80         dim=1,
81     )
82
83
84 def sample2img(seq, height, width):
85     f_start = seq[:, : height * width].reshape(-1, height, width)
86     f_start = (f_start >= len(colors)).long() + (f_start < len(colors)).long() * f_start
87     f_end = seq[:, height * width + 1 :].reshape(-1, height, width)
88     f_end = (f_end >= len(colors)).long() + (f_end < len(colors)).long() * f_end
89
90     img_f_start, img_f_end = colors[f_start], colors[f_end]
91
92     img = torch.cat(
93         [
94             img_f_start,
95             torch.full(
96                 (img_f_start.size(0), img_f_start.size(1), 1, img_f_start.size(3)), 1
97             ),
98             img_f_end,
99         ],
100         dim=2,
101     )
102
103     return img.permute(0, 3, 1, 2)
104
105
106 def seq2str(seq):
107     result = []
108     for s in seq:
109         result.append("".join([token2char[v] for v in s]))
110     return result
111
112
113 ######################################################################
114
115 if __name__ == "__main__":
116     import time
117
118     height, width = 6, 8
119     start_time = time.perf_counter()
120     seq = generate(nb=64, height=height, width=width, max_nb_obj=3)
121     delay = time.perf_counter() - start_time
122     print(f"{seq.size(0)/delay:02f} samples/s")
123
124     print(seq2str(seq[:4]))
125
126     img = sample2img(seq, height, width)
127     print(img.size())
128
129     torchvision.utils.save_image(img.float() / 255.0, "world.png", nrow=8, padding=2)