Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11
12 class Box:
13     nb_rgb_levels = 10
14
15     def __init__(self, x, y, w, h, r, g, b):
16         self.x = x
17         self.y = y
18         self.w = w
19         self.h = h
20         self.r = r
21         self.g = g
22         self.b = b
23
24     def collision(self, scene):
25         for c in scene:
26             if (
27                 self is not c
28                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
29                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
30             ):
31                 return True
32         return False
33
34
35 def scene2tensor(xh, yh, scene, size):
36     width, height = size, size
37     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
38     data = pixel_map.numpy()
39     surface = cairo.ImageSurface.create_for_data(
40         data, cairo.FORMAT_ARGB32, width, height
41     )
42
43     ctx = cairo.Context(surface)
44     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
45
46     for b in scene:
47         ctx.move_to(b.x * size, b.y * size)
48         ctx.rel_line_to(b.w * size, 0)
49         ctx.rel_line_to(0, b.h * size)
50         ctx.rel_line_to(-b.w * size, 0)
51         ctx.close_path()
52         ctx.set_source_rgba(
53             b.r / (Box.nb_rgb_levels - 1),
54             b.g / (Box.nb_rgb_levels - 1),
55             b.b / (Box.nb_rgb_levels - 1),
56             1.0,
57         )
58         ctx.fill()
59
60     hs = size * 0.1
61     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
62     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
63     ctx.rel_line_to(hs, 0)
64     ctx.rel_line_to(0, hs)
65     ctx.rel_line_to(-hs, 0)
66     ctx.close_path()
67     ctx.fill()
68
69     return (
70         pixel_map[None, :, :, :3]
71         .flip(-1)
72         .permute(0, 3, 1, 2)
73         .long()
74         .mul(Box.nb_rgb_levels)
75         .floor_divide(256)
76     )
77
78
79 def random_scene():
80     scene = []
81     colors = [
82         ((Box.nb_rgb_levels - 1), 0, 0),
83         (0, (Box.nb_rgb_levels - 1), 0),
84         (0, 0, (Box.nb_rgb_levels - 1)),
85         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
86         (
87             (Box.nb_rgb_levels * 2) // 3,
88             (Box.nb_rgb_levels * 2) // 3,
89             (Box.nb_rgb_levels * 2) // 3,
90         ),
91     ]
92
93     for k in range(10):
94         wh = torch.rand(2) * 0.2 + 0.2
95         xy = torch.rand(2) * (1 - wh)
96         c = colors[torch.randint(len(colors), (1,))]
97         b = Box(
98             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
99         )
100         if not b.collision(scene):
101             scene.append(b)
102
103     return scene
104
105
106 def generate_episode(nb_steps=10, size=64):
107     delta = 0.1
108     effects = [
109         (False, 0, 0),
110         (False, delta, 0),
111         (False, 0, delta),
112         (False, -delta, 0),
113         (False, 0, -delta),
114         (True, delta, 0),
115         (True, 0, delta),
116         (True, -delta, 0),
117         (True, 0, -delta),
118     ]
119
120     while True:
121         frames = []
122
123         scene = random_scene()
124         xh, yh = tuple(x.item() for x in torch.rand(2))
125
126         frames.append(scene2tensor(xh, yh, scene, size=size))
127
128         actions = torch.randint(len(effects), (nb_steps,))
129         change = False
130
131         for a in actions:
132             g, dx, dy = effects[a]
133             if g:
134                 for b in scene:
135                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
136                         x, y = b.x, b.y
137                         b.x += dx
138                         b.y += dy
139                         if (
140                             b.x < 0
141                             or b.y < 0
142                             or b.x + b.w > 1
143                             or b.y + b.h > 1
144                             or b.collision(scene)
145                         ):
146                             b.x, b.y = x, y
147                         else:
148                             xh += dx
149                             yh += dy
150                             change = True
151             else:
152                 x, y = xh, yh
153                 xh += dx
154                 yh += dy
155                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
156                     xh, yh = x, y
157
158             frames.append(scene2tensor(xh, yh, scene, size=size))
159
160         if change:
161             break
162
163     return frames, actions
164
165
166 ######################################################################
167
168
169 # ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2<x_i, c_j>
170 def sq2matrix(x, c):
171     nx = x.pow(2).sum(1)
172     nc = c.pow(2).sum(1)
173     return nx[:, None] + nc[None, :] - 2 * x @ c.t()
174
175
176 def update_centroids(x, c, nb_min=1):
177     _, b = sq2matrix(x, c).min(1)
178     b.squeeze_()
179     nb_resets = 0
180
181     for k in range(0, c.size(0)):
182         i = b.eq(k).nonzero(as_tuple=False).squeeze()
183         if i.numel() >= nb_min:
184             c[k] = x.index_select(0, i).mean(0)
185         else:
186             n = torch.randint(x.size(0), (1,))
187             nb_resets += 1
188             c[k] = x[n]
189
190     return c, b, nb_resets
191
192
193 def kmeans(x, nb_centroids, nb_min=1):
194     if x.size(0) < nb_centroids * nb_min:
195         print("Not enough points!")
196         exit(1)
197
198     c = x[torch.randperm(x.size(0))[:nb_centroids]]
199     t = torch.full((x.size(0),), -1)
200     n = 0
201
202     while True:
203         c, u, nb_resets = update_centroids(x, c, nb_min)
204         n = n + 1
205         nb_changes = (u - t).sign().abs().sum() + nb_resets
206         t = u
207         if nb_changes == 0:
208             break
209
210     return c, t
211
212
213 ######################################################################
214
215
216 def patchify(x, factor, invert_size=None):
217     if invert_size is None:
218         return (
219             x.reshape(
220                 x.size(0),  # 0
221                 x.size(1),  # 1
222                 factor,  # 2
223                 x.size(2) // factor,  # 3
224                 factor,  # 4
225                 x.size(3) // factor,  # 5
226             )
227             .permute(0, 2, 4, 1, 3, 5)
228             .reshape(-1, x.size(1), x.size(2) // factor, x.size(3) // factor)
229         )
230     else:
231         return (
232             x.reshape(
233                 invert_size[0],  # 0
234                 factor,  # 1
235                 factor,  # 2
236                 invert_size[1],  # 3
237                 invert_size[2] // factor,  # 4
238                 invert_size[3] // factor,  # 5
239             )
240             .permute(0, 3, 1, 4, 2, 5)
241             .reshape(invert_size)
242         )
243
244
245 class Normalizer(nn.Module):
246     def __init__(self, mu, std):
247         super().__init__()
248         self.register_buffer("mu", mu)
249         self.register_buffer("log_var", 2 * torch.log(std))
250
251     def forward(self, x):
252         return (x - self.mu) / torch.exp(self.log_var / 2.0)
253
254
255 class SignSTE(nn.Module):
256     def __init__(self):
257         super().__init__()
258
259     def forward(self, x):
260         # torch.sign() takes three values
261         s = (x >= 0).float() * 2 - 1
262         if self.training:
263             u = torch.tanh(x)
264             return s + u - u.detach()
265         else:
266             return s
267
268
269 def train_encoder(
270     train_input,
271     test_input,
272     depth=2,
273     dim_hidden=48,
274     nb_bits_per_token=10,
275     lr_start=1e-3,
276     lr_end=1e-4,
277     nb_epochs=10,
278     batch_size=25,
279     device=torch.device("cpu"),
280 ):
281     mu, std = train_input.float().mean(), train_input.float().std()
282
283     def encoder_core(depth, dim):
284         l = [
285             [
286                 nn.Conv2d(
287                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
288                 ),
289                 nn.ReLU(),
290                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
291                 nn.ReLU(),
292             ]
293             for k in range(depth)
294         ]
295
296         return nn.Sequential(*[x for m in l for x in m])
297
298     def decoder_core(depth, dim):
299         l = [
300             [
301                 nn.ConvTranspose2d(
302                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
303                 ),
304                 nn.ReLU(),
305                 nn.ConvTranspose2d(
306                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
307                 ),
308                 nn.ReLU(),
309             ]
310             for k in range(depth - 1, -1, -1)
311         ]
312
313         return nn.Sequential(*[x for m in l for x in m])
314
315     encoder = nn.Sequential(
316         Normalizer(mu, std),
317         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
318         nn.ReLU(),
319         # 64x64
320         encoder_core(depth=depth, dim=dim_hidden),
321         # 8x8
322         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
323     )
324
325     quantizer = SignSTE()
326
327     decoder = nn.Sequential(
328         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
329         # 8x8
330         decoder_core(depth=depth, dim=dim_hidden),
331         # 64x64
332         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
333     )
334
335     model = nn.Sequential(encoder, decoder)
336
337     nb_parameters = sum(p.numel() for p in model.parameters())
338
339     print(f"nb_parameters {nb_parameters}")
340
341     model.to(device)
342
343     g5x5 = torch.exp(-torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]]) ** 2 / 2)
344     g5x5 = (g5x5.t() @ g5x5).view(1, 1, 5, 5)
345     g5x5 = g5x5 / g5x5.sum()
346
347     for k in range(nb_epochs):
348         lr = math.exp(
349             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
350         )
351         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
352
353         acc_train_loss = 0.0
354
355         for input in train_input.split(batch_size):
356             z = encoder(input)
357             zq = z if k < 1 else quantizer(z)
358             output = decoder(zq)
359
360             output = output.reshape(
361                 output.size(0), -1, 3, output.size(2), output.size(3)
362             )
363
364             train_loss = F.cross_entropy(output, input)
365
366             acc_train_loss += train_loss.item() * input.size(0)
367
368             optimizer.zero_grad()
369             train_loss.backward()
370             optimizer.step()
371
372         acc_test_loss = 0.0
373
374         for input in test_input.split(batch_size):
375             z = encoder(input)
376             zq = z if k < 1 else quantizer(z)
377             output = decoder(zq)
378
379             output = output.reshape(
380                 output.size(0), -1, 3, output.size(2), output.size(3)
381             )
382
383             test_loss = F.cross_entropy(output, input)
384
385             acc_test_loss += test_loss.item() * input.size(0)
386
387         train_loss = acc_train_loss / train_input.size(0)
388         test_loss = acc_test_loss / test_input.size(0)
389
390         print(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
391         sys.stdout.flush()
392
393     return encoder, quantizer, decoder
394
395 def generate_episodes(nb):
396     all_frames = []
397     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
398         frames, actions = generate_episode(nb_steps=31)
399         all_frames += [ frames[0], frames[-1] ]
400     return torch.cat(all_frames, 0).contiguous()
401
402 def create_data_and_processors(nb_train_samples, nb_test_samples):
403     train_input = generate_episodes(nb_train_samples)
404     test_input = generate_episodes(nb_test_samples)
405     encoder, quantizer, decoder = train_encoder(train_input, test_input, nb_epochs=2)
406
407     input = test_input[:64]
408
409     z = encoder(input.float())
410     height, width = z.size(2), z.size(3)
411     zq = quantizer(z).long()
412     pow2=(2**torch.arange(zq.size(1), device=zq.device))[None,None,:]
413     seq = (zq.permute(0,2,3,1).clamp(min=0).reshape(zq.size(0),-1,zq.size(1)) * pow2).sum(-1)
414     print(f"{seq.size()=}")
415
416     ZZ=zq
417
418     zq = ((seq[:,:,None] // pow2)%2)*2-1
419     zq = zq.reshape(zq.size(0), height, width, -1).permute(0,3,1,2)
420
421     print(ZZ[0])
422     print(zq[0])
423
424     print("CHECK", (ZZ-zq).abs().sum())
425
426     results = decoder(zq.float())
427     T = 0.1
428     results = results.reshape(
429         results.size(0), -1, 3, results.size(2), results.size(3)
430     ).permute(0, 2, 3, 4, 1)
431     results = torch.distributions.categorical.Categorical(logits=results / T).sample()
432
433
434     torchvision.utils.save_image(
435         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
436     )
437
438     torchvision.utils.save_image(
439         results.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
440     )
441
442
443 ######################################################################
444
445 if __name__ == "__main__":
446     create_data_and_processors(250,100)
447
448     # train_input = generate_episodes(2500)
449     # test_input = generate_episodes(1000)
450
451     # encoder, quantizer, decoder = train_encoder(train_input, test_input)
452
453     # input = test_input[torch.randperm(test_input.size(0))[:64]]
454     # z = encoder(input.float())
455     # zq = quantizer(z)
456     # results = decoder(zq)
457
458     # T = 0.1
459     # results = torch.distributions.categorical.Categorical(logits=results / T).sample()