Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math, sys, tqdm
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11 ######################################################################
12
13
14 class Box:
15     nb_rgb_levels = 10
16
17     def __init__(self, x, y, w, h, r, g, b):
18         self.x = x
19         self.y = y
20         self.w = w
21         self.h = h
22         self.r = r
23         self.g = g
24         self.b = b
25
26     def collision(self, scene):
27         for c in scene:
28             if (
29                 self is not c
30                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
31                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
32             ):
33                 return True
34         return False
35
36
37 ######################################################################
38
39
40 class Normalizer(nn.Module):
41     def __init__(self, mu, std):
42         super().__init__()
43         self.register_buffer("mu", mu)
44         self.register_buffer("log_var", 2 * torch.log(std))
45
46     def forward(self, x):
47         return (x - self.mu) / torch.exp(self.log_var / 2.0)
48
49
50 class SignSTE(nn.Module):
51     def __init__(self):
52         super().__init__()
53
54     def forward(self, x):
55         # torch.sign() takes three values
56         s = (x >= 0).float() * 2 - 1
57
58         if self.training:
59             u = torch.tanh(x)
60             return s + u - u.detach()
61         else:
62             return s
63
64
65 class DiscreteSampler2d(nn.Module):
66     def __init__(self):
67         super().__init__()
68
69     def forward(self, x):
70         s = (x >= x.max(-3, keepdim=True).values).float()
71
72         if self.training:
73             u = x.softmax(dim=-3)
74             return s + u - u.detach()
75         else:
76             return s
77
78
79 def loss_H(binary_logits, h_threshold=1):
80     p = binary_logits.sigmoid().mean(0)
81     h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
82     h.clamp_(max=h_threshold)
83     return h_threshold - h.mean()
84
85
86 def train_encoder(
87     train_input,
88     test_input,
89     depth,
90     nb_bits_per_token,
91     dim_hidden=48,
92     lambda_entropy=0.0,
93     lr_start=1e-3,
94     lr_end=1e-4,
95     nb_epochs=10,
96     batch_size=25,
97     logger=None,
98     device=torch.device("cpu"),
99 ):
100     mu, std = train_input.float().mean(), train_input.float().std()
101
102     def encoder_core(depth, dim):
103         l = [
104             [
105                 nn.Conv2d(
106                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
107                 ),
108                 nn.ReLU(),
109                 nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
110                 nn.ReLU(),
111             ]
112             for k in range(depth)
113         ]
114
115         return nn.Sequential(*[x for m in l for x in m])
116
117     def decoder_core(depth, dim):
118         l = [
119             [
120                 nn.ConvTranspose2d(
121                     dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
122                 ),
123                 nn.ReLU(),
124                 nn.ConvTranspose2d(
125                     dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
126                 ),
127                 nn.ReLU(),
128             ]
129             for k in range(depth - 1, -1, -1)
130         ]
131
132         return nn.Sequential(*[x for m in l for x in m])
133
134     encoder = nn.Sequential(
135         Normalizer(mu, std),
136         nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
137         nn.ReLU(),
138         # 64x64
139         encoder_core(depth=depth, dim=dim_hidden),
140         # 8x8
141         nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
142     )
143
144     quantizer = SignSTE()
145
146     decoder = nn.Sequential(
147         nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
148         # 8x8
149         decoder_core(depth=depth, dim=dim_hidden),
150         # 64x64
151         nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
152     )
153
154     model = nn.Sequential(encoder, decoder)
155
156     nb_parameters = sum(p.numel() for p in model.parameters())
157
158     logger(f"vqae nb_parameters {nb_parameters}")
159
160     model.to(device)
161
162     for k in range(nb_epochs):
163         lr = math.exp(
164             math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
165         )
166         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
167
168         acc_train_loss = 0.0
169
170         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
171             input = input.to(device)
172             z = encoder(input)
173             zq = quantizer(z)
174             output = decoder(zq)
175
176             output = output.reshape(
177                 output.size(0), -1, 3, output.size(2), output.size(3)
178             )
179
180             train_loss = F.cross_entropy(output, input)
181
182             if lambda_entropy > 0:
183                 train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
184
185             acc_train_loss += train_loss.item() * input.size(0)
186
187             optimizer.zero_grad()
188             train_loss.backward()
189             optimizer.step()
190
191         acc_test_loss = 0.0
192
193         for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
194             input = input.to(device)
195             z = encoder(input)
196             zq = quantizer(z)
197             output = decoder(zq)
198
199             output = output.reshape(
200                 output.size(0), -1, 3, output.size(2), output.size(3)
201             )
202
203             test_loss = F.cross_entropy(output, input)
204
205             acc_test_loss += test_loss.item() * input.size(0)
206
207         train_loss = acc_train_loss / train_input.size(0)
208         test_loss = acc_test_loss / test_input.size(0)
209
210         logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
211         sys.stdout.flush()
212
213     return encoder, quantizer, decoder
214
215
216 ######################################################################
217
218
219 def scene2tensor(xh, yh, scene, size):
220     width, height = size, size
221     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
222     data = pixel_map.numpy()
223     surface = cairo.ImageSurface.create_for_data(
224         data, cairo.FORMAT_ARGB32, width, height
225     )
226
227     ctx = cairo.Context(surface)
228     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
229
230     for b in scene:
231         ctx.move_to(b.x * size, b.y * size)
232         ctx.rel_line_to(b.w * size, 0)
233         ctx.rel_line_to(0, b.h * size)
234         ctx.rel_line_to(-b.w * size, 0)
235         ctx.close_path()
236         ctx.set_source_rgba(
237             b.r / (Box.nb_rgb_levels - 1),
238             b.g / (Box.nb_rgb_levels - 1),
239             b.b / (Box.nb_rgb_levels - 1),
240             1.0,
241         )
242         ctx.fill()
243
244     hs = size * 0.1
245     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
246     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
247     ctx.rel_line_to(hs, 0)
248     ctx.rel_line_to(0, hs)
249     ctx.rel_line_to(-hs, 0)
250     ctx.close_path()
251     ctx.fill()
252
253     return (
254         pixel_map[None, :, :, :3]
255         .flip(-1)
256         .permute(0, 3, 1, 2)
257         .long()
258         .mul(Box.nb_rgb_levels)
259         .floor_divide(256)
260     )
261
262
263 def random_scene(nb_insert_attempts=3):
264     scene = []
265     colors = [
266         ((Box.nb_rgb_levels - 1), 0, 0),
267         (0, (Box.nb_rgb_levels - 1), 0),
268         (0, 0, (Box.nb_rgb_levels - 1)),
269         ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
270         (
271             (Box.nb_rgb_levels * 2) // 3,
272             (Box.nb_rgb_levels * 2) // 3,
273             (Box.nb_rgb_levels * 2) // 3,
274         ),
275     ]
276
277     for k in range(nb_insert_attempts):
278         wh = torch.rand(2) * 0.2 + 0.2
279         xy = torch.rand(2) * (1 - wh)
280         c = colors[torch.randint(len(colors), (1,))]
281         b = Box(
282             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
283         )
284         if not b.collision(scene):
285             scene.append(b)
286
287     return scene
288
289
290 def generate_episode(steps, size=64):
291     delta = 0.1
292     effects = [
293         (False, 0, 0),
294         (False, delta, 0),
295         (False, 0, delta),
296         (False, -delta, 0),
297         (False, 0, -delta),
298         (True, delta, 0),
299         (True, 0, delta),
300         (True, -delta, 0),
301         (True, 0, -delta),
302     ]
303
304     while True:
305         frames = []
306
307         scene = random_scene()
308         xh, yh = tuple(x.item() for x in torch.rand(2))
309
310         actions = torch.randint(len(effects), (len(steps),))
311         nb_changes = 0
312
313         for s, a in zip(steps, actions):
314             if s:
315                 frames.append(scene2tensor(xh, yh, scene, size=size))
316
317             grasp, dx, dy = effects[a]
318
319             if grasp:
320                 for b in scene:
321                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
322                         x, y = b.x, b.y
323                         b.x += dx
324                         b.y += dy
325                         if (
326                             b.x < 0
327                             or b.y < 0
328                             or b.x + b.w > 1
329                             or b.y + b.h > 1
330                             or b.collision(scene)
331                         ):
332                             b.x, b.y = x, y
333                         else:
334                             xh += dx
335                             yh += dy
336                             nb_changes += 1
337             else:
338                 x, y = xh, yh
339                 xh += dx
340                 yh += dy
341                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
342                     xh, yh = x, y
343
344         if nb_changes > len(steps) // 3:
345             break
346
347     return frames, actions
348
349
350 ######################################################################
351
352
353 def generate_episodes(nb, steps):
354     all_frames, all_actions = [], []
355     for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
356         frames, actions = generate_episode(steps)
357         all_frames += frames
358         all_actions += [actions[None, :]]
359     return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
360
361
362 def create_data_and_processors(
363     nb_train_samples,
364     nb_test_samples,
365     mode,
366     nb_steps,
367     depth=3,
368     nb_bits_per_token=8,
369     nb_epochs=10,
370     device=torch.device("cpu"),
371     device_storage=torch.device("cpu"),
372     logger=None,
373 ):
374     assert mode in ["first_last"]
375
376     if mode == "first_last":
377         steps = [True] + [False] * (nb_steps + 1) + [True]
378
379     if logger is None:
380         logger = lambda s: print(s)
381
382     train_input, train_actions = generate_episodes(nb_train_samples, steps)
383     train_input, train_actions = train_input.to(device_storage), train_actions.to(
384         device_storage
385     )
386     test_input, test_actions = generate_episodes(nb_test_samples, steps)
387     test_input, test_actions = test_input.to(device_storage), test_actions.to(
388         device_storage
389     )
390
391     encoder, quantizer, decoder = train_encoder(
392         train_input,
393         test_input,
394         depth=depth,
395         nb_bits_per_token=nb_bits_per_token,
396         lambda_entropy=1.0,
397         nb_epochs=nb_epochs,
398         logger=logger,
399         device=device,
400     )
401     encoder.train(False)
402     quantizer.train(False)
403     decoder.train(False)
404
405     z = encoder(train_input[:1].to(device))
406     pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
407     z_h, z_w = z.size(2), z.size(3)
408
409     logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
410
411     def frame2seq(input, batch_size=25):
412         seq = []
413         p = pow2.to(device)
414         for x in input.split(batch_size):
415             x = x.to(device)
416             z = encoder(x)
417             ze_bool = (quantizer(z) >= 0).long()
418             output = (
419                 ze_bool.permute(0, 2, 3, 1).reshape(
420                     ze_bool.size(0), -1, ze_bool.size(1)
421                 )
422                 * p
423             ).sum(-1)
424
425             seq.append(output)
426
427         return torch.cat(seq, dim=0)
428
429     def seq2frame(input, batch_size=25, T=1e-2):
430         frames = []
431         p = pow2.to(device)
432         for seq in input.split(batch_size):
433             seq = seq.to(device)
434             zd_bool = (seq[:, :, None] // p) % 2
435             zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
436             logits = decoder(zd_bool * 2.0 - 1.0)
437             logits = logits.reshape(
438                 logits.size(0), -1, 3, logits.size(2), logits.size(3)
439             ).permute(0, 2, 3, 4, 1)
440             output = torch.distributions.categorical.Categorical(
441                 logits=logits / T
442             ).sample()
443
444             frames.append(output)
445
446         return torch.cat(frames, dim=0)
447
448     return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
449
450
451 ######################################################################
452
453 if __name__ == "__main__":
454     (
455         train_input,
456         train_actions,
457         test_input,
458         test_actions,
459         frame2seq,
460         seq2frame,
461     ) = create_data_and_processors(
462         25000,
463         1000,
464         nb_epochs=5,
465         mode="first_last",
466         nb_steps=20,
467     )
468
469     input = test_input[:256]
470
471     seq = frame2seq(input)
472     output = seq2frame(seq)
473
474     torchvision.utils.save_image(
475         input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16
476     )
477
478     torchvision.utils.save_image(
479         output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16
480     )