0392940d22b6af336e2f18103d18fa09aac62f6a
[culture.git] / world.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math, sys, tqdm
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
17
18 colors = torch.tensor(
19     [
20         [255, 255, 255],
21         [0, 0, 0],
22         [255, 0, 0],
23         [0, 128, 0],
24         [0, 0, 255],
25         [255, 255, 0],
26         [192, 192, 192],
27     ]
28 )
29
30 token2char = "_X01234>"
31
32
33 def generate(
34     nb,
35     height,
36     width,
37     obj_length=6,
38     mask_height=3,
39     mask_width=3,
40     nb_obj=3,
41 ):
42     intact = torch.zeros(nb, height, width, dtype=torch.int64)
43     n = torch.arange(intact.size(0))
44
45     for n in range(nb):
46         for c in torch.randperm(colors.size(0) - 2)[:nb_obj] + 2:
47             z = intact[n].flatten()
48             m = (torch.rand(z.size()) * (z == 0)).argmax(dim=0)
49             i, j = m // width, m % width
50             vm = torch.randint(4, (1,))[0]
51             vi, vj = (vm // 2) * (2 * (vm % 2) - 1), (1 - vm // 2) * (2 * (vm % 2) - 1)
52             for l in range(obj_length):
53                 intact[n, i, j] = c
54                 i += vi
55                 j += vj
56                 if i < 0 or i >= height or j < 0 or j >= width or intact[n, i, j] != 0:
57                     i -= vi
58                     j -= vj
59                     vi, vj = -vj, vi
60                     i += vi
61                     j += vj
62                     if (
63                         i < 0
64                         or i >= height
65                         or j < 0
66                         or j >= width
67                         or intact[n, i, j] != 0
68                     ):
69                         break
70
71     masked = intact.clone()
72
73     for n in range(nb):
74         i = torch.randint(height - mask_height + 1, (1,))[0]
75         j = torch.randint(width - mask_width + 1, (1,))[0]
76         masked[n, i : i + mask_height, j : j + mask_width] = 1
77
78     return torch.cat(
79         [
80             masked.flatten(1),
81             torch.full((masked.size(0), 1), len(colors)),
82             intact.flatten(1),
83         ],
84         dim=1,
85     )
86
87
88 def sample2img(seq, height, width):
89     intact = seq[:, : height * width].reshape(-1, height, width)
90     masked = seq[:, height * width + 1 :].reshape(-1, height, width)
91     img_intact, img_masked = colors[intact], colors[masked]
92
93     img = torch.cat(
94         [
95             img_intact,
96             torch.full(
97                 (img_intact.size(0), img_intact.size(1), 1, img_intact.size(3)), 1
98             ),
99             img_masked,
100         ],
101         dim=2,
102     )
103
104     return img.permute(0, 3, 1, 2)
105
106
107 def seq2str(seq):
108     result = []
109     for s in seq:
110         result.append("".join([token2char[v] for v in s]))
111     return result
112
113
114 ######################################################################
115
116 if __name__ == "__main__":
117     import time
118
119     height, width = 6, 8
120     start_time = time.perf_counter()
121     seq = generate(nb=64, height=height, width=width)
122     delay = time.perf_counter() - start_time
123     print(f"{seq.size(0)/delay:02f} samples/s")
124
125     print(seq2str(seq[:4]))
126
127     img = sample2img(seq, height, width)
128     print(img.size())
129
130     torchvision.utils.save_image(img.float() / 255.0, "world.png", nrow=8, padding=2)