Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64 class DiscreteSampler2d(nn.Module):
65     def __init__(self):
66         super().__init__()
67
68     def forward(self, x):
69         s = (x >= x.max(-3,keepdim=True).values).float()
70
71         if self.training:
72             u = x.softmax(dim=-3)
73             return s + u - u.detach()
74         else:
75             return s
76
77
78 def loss_H(binary_logits, h_threshold=1):
79     p = binary_logits.sigmoid().mean(0)
80     h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
81     h.clamp_(max=h_threshold)
82     return h_threshold - h.mean()
83
84
85 def train_encoder(
86     train_input,
87     test_input,
88     depth=2,
89     dim_hidden=48,
90     nb_bits_per_token=8,
91     lambda_entropy=0.0,
92     lr_start=1e-3,
93     lr_end=1e-4,
94     nb_epochs=10,
95     batch_size=25,
96     logger=None,
97     device=torch.device("cpu"),
98 ):
99     if logger is None:
100         logger = lambda s: print(s)
101
102     mu, std = train_input.float().mean(), train_input.float().std()
103
104     def encoder_core(depth, dim):
105         l = [
106             [
107                 nn.Conv2d(
108                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
109                 ),
110                 nn.ReLU(),
111                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
112                 nn.ReLU(),
113             ]
114             for k in range(depth)
115         ]
116
117         return nn.Sequential(*[x for m in l for x in m])
118
119     def decoder_core(depth, dim):
120         l = [
121             [
122                 nn.ConvTranspose2d(
123                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
124                 ),
125                 nn.ReLU(),
126                 nn.ConvTranspose2d(
127                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
128                 ),
129                 nn.ReLU(),
130             ]
131             for k in range(depth - 1, -1, -1)
132         ]
133
134         return nn.Sequential(*[x for m in l for x in m])
135
136     encoder = nn.Sequential(
137         Normalizer(mu, std),
138         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
139         nn.ReLU(),
140         # 64x64
141         encoder_core(depth=depth, dim=dim_hidden),
142         # 8x8
143         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
144     )
145
146     quantizer = SignSTE()
147
148     decoder = nn.Sequential(
149         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
150         # 8x8
151         decoder_core(depth=depth, dim=dim_hidden),
152         # 64x64
153         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
154     )
155
156     model = nn.Sequential(encoder, decoder)
157
158     nb_parameters = sum(p.numel() for p in model.parameters())
159
160     logger(f"nb_parameters {nb_parameters}")
161
162     model.to(device)
163
164     for k in range(nb_epochs):
165         lr = math.exp(
166             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
167         )
168         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
169
170         acc_train_loss = 0.0
171
172         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
173             input = input.to(device)
174             z = encoder(input)
175             zq = quantizer(z)
176             output = decoder(zq)
177
178             output = output.reshape(
179                 output.size(0), -1, 3, output.size(2), output.size(3)
180             )
181
182             train_loss = F.cross_entropy(output, input)
183
184             if lambda_entropy > 0:
185                 train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
186
187             acc_train_loss += train_loss.item() * input.size(0)
188
189             optimizer.zero_grad()
190             train_loss.backward()
191             optimizer.step()
192
193         acc_test_loss = 0.0
194
195         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
196             input = input.to(device)
197             z = encoder(input)
198             zq = quantizer(z)
199             output = decoder(zq)
200
201             output = output.reshape(
202                 output.size(0), -1, 3, output.size(2), output.size(3)
203             )
204
205             test_loss = F.cross_entropy(output, input)
206
207             acc_test_loss += test_loss.item() * input.size(0)
208
209         train_loss = acc_train_loss / train_input.size(0)
210         test_loss = acc_test_loss / test_input.size(0)
211
212         logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
213         sys.stdout.flush()
214
215     return encoder, quantizer, decoder
216
217
218 ######################################################################
219
220
221 def scene2tensor(xh, yh, scene, size):
222     width, height = size, size
223     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
224     data = pixel_map.numpy()
225     surface = cairo.ImageSurface.create_for_data(
226         data, cairo.FORMAT_ARGB32, width, height
227     )
228
229     ctx = cairo.Context(surface)
230     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
231
232     for b in scene:
233         ctx.move_to(b.x * size, b.y * size)
234         ctx.rel_line_to(b.w * size, 0)
235         ctx.rel_line_to(0, b.h * size)
236         ctx.rel_line_to(-b.w * size, 0)
237         ctx.close_path()
238         ctx.set_source_rgba(
239             b.r / (Box.nb_rgb_levels - 1),
240             b.g / (Box.nb_rgb_levels - 1),
241             b.b / (Box.nb_rgb_levels - 1),
242             1.0,
243         )
244         ctx.fill()
245
246     hs = size * 0.1
247     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
248     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
249     ctx.rel_line_to(hs, 0)
250     ctx.rel_line_to(0, hs)
251     ctx.rel_line_to(-hs, 0)
252     ctx.close_path()
253     ctx.fill()
254
255     return (
256         pixel_map[None, :, :, :3]
257         .flip(-1)
258         .permute(0, 3, 1, 2)
259         .long()
260         .mul(Box.nb_rgb_levels)
261         .floor_divide(256)
262     )
263
264
265 def random_scene(nb_insert_attempts=3):
266     scene = []
267     colors = [
268         ((Box.nb_rgb_levels - 1), 0, 0),
269         (0, (Box.nb_rgb_levels - 1), 0),
270         (0, 0, (Box.nb_rgb_levels - 1)),
271         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
272         (
273             (Box.nb_rgb_levels * 2) // 3,
274             (Box.nb_rgb_levels * 2) // 3,
275             (Box.nb_rgb_levels * 2) // 3,
276         ),
277     ]
278
279     for k in range(nb_insert_attempts):
280         wh = torch.rand(2) * 0.2 + 0.2
281         xy = torch.rand(2) * (1 - wh)
282         c = colors[torch.randint(len(colors), (1,))]
283         b = Box(
284             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
285         )
286         if not b.collision(scene):
287             scene.append(b)
288
289     return scene
290
291
292 def generate_episode(steps, size=64):
293     delta = 0.1
294     effects = [
295         (False, 0, 0),
296         (False, delta, 0),
297         (False, 0, delta),
298         (False, -delta, 0),
299         (False, 0, -delta),
300         (True, delta, 0),
301         (True, 0, delta),
302         (True, -delta, 0),
303         (True, 0, -delta),
304     ]
305
306     while True:
307         frames = []
308
309         scene = random_scene()
310         xh, yh = tuple(x.item() for x in torch.rand(2))
311
312         actions = torch.randint(len(effects), (len(steps),))
313         nb_changes = 0
314
315         for s, a in zip(steps, actions):
316             if s:
317                 frames.append(scene2tensor(xh, yh, scene, size=size))
318
319             grasp, dx, dy = effects[a]
320
321             if grasp:
322                 for b in scene:
323                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
324                         x, y = b.x, b.y
325                         b.x += dx
326                         b.y += dy
327                         if (
328                             b.x < 0
329                             or b.y < 0
330                             or b.x + b.w > 1
331                             or b.y + b.h > 1
332                             or b.collision(scene)
333                         ):
334                             b.x, b.y = x, y
335                         else:
336                             xh += dx
337                             yh += dy
338                             nb_changes += 1
339             else:
340                 x, y = xh, yh
341                 xh += dx
342                 yh += dy
343                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
344                     xh, yh = x, y
345
346         if nb_changes > len(steps) // 3:
347             break
348
349     return frames, actions
350
351
352 ######################################################################
353
354
355 def generate_episodes(nb, steps):
356     all_frames, all_actions = [], []
357     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
358         frames, actions = generate_episode(steps)
359         all_frames += frames
360         all_actions += [actions[None, :]]
361     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
362
363
364 def create_data_and_processors(
365     nb_train_samples,
366     nb_test_samples,
367     mode,
368     nb_steps,
369     nb_epochs=10,
370     device=torch.device("cpu"),
371     device_storage=torch.device("cpu"),
372     logger=None,
373 ):
374     assert mode in ["first_last"]
375
376     if mode == "first_last":
377         steps = [True] + [False] * (nb_steps + 1) + [True]
378
379     train_input, train_actions = generate_episodes(nb_train_samples, steps)
380     train_input, train_actions = train_input.to(device_storage), train_actions.to(
381         device_storage
382     )
383     test_input, test_actions = generate_episodes(nb_test_samples, steps)
384     test_input, test_actions = test_input.to(device_storage), test_actions.to(
385         device_storage
386     )
387
388     encoder, quantizer, decoder = train_encoder(
389         train_input,
390         test_input,
391         lambda_entropy=1.0,
392         nb_epochs=nb_epochs,
393         logger=logger,
394         device=device,
395     )
396     encoder.train(False)
397     quantizer.train(False)
398     decoder.train(False)
399
400     z = encoder(train_input[:1].to(device))
401     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
402     z_h, z_w = z.size(2), z.size(3)
403
404     def frame2seq(input, batch_size=25):
405         seq = []
406         p = pow2.to(device)
407         for x in input.split(batch_size):
408             x = x.to(device)
409             z = encoder(x)
410             ze_bool = (quantizer(z) >= 0).long()
411             output = (
412                 ze_bool.permute(0, 2, 3, 1).reshape(
413                     ze_bool.size(0), -1, ze_bool.size(1)
414                 )
415                 * p
416             ).sum(-1)
417
418             seq.append(output)
419
420         return torch.cat(seq, dim=0)
421
422     def seq2frame(input, batch_size=25, T=1e-2):
423         frames = []
424         p = pow2.to(device)
425         for seq in input.split(batch_size):
426             seq = seq.to(device)
427             zd_bool = (seq[:, :, None] // p) % 2
428             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
429             logits = decoder(zd_bool * 2.0 - 1.0)
430             logits = logits.reshape(
431                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
432             ).permute(0, 2, 3, 4, 1)
433             output = torch.distributions.categorical.Categorical(
434                 logits=logits / T
435             ).sample()
436
437             frames.append(output)
438
439         return torch.cat(frames, dim=0)
440
441     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
442
443
444 ######################################################################
445
446 if __name__ == "__main__":
447     (
448         train_input,
449         train_actions,
450         test_input,
451         test_actions,
452         frame2seq,
453         seq2frame,
454     ) = create_data_and_processors(
455         25000, 1000,
456         nb_epochs=5,
457         mode="first_last",
458         nb_steps=20,
459     )
460
461     input = test_input[:256]
462
463     seq = frame2seq(input)
464     output = seq2frame(seq)
465
466     torchvision.utils.save_image(
467         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
468     )
469
470     torchvision.utils.save_image(
471         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
472     )