Update.
[picoclvr.git] / world.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9 import cairo
10
11
12 class Box:
13     def __init__(self, x, y, w, h, r, g, b):
14         self.x = x
15         self.y = y
16         self.w = w
17         self.h = h
18         self.r = r
19         self.g = g
20         self.b = b
21
22     def collision(self, scene):
23         for c in scene:
24             if (
25                 self is not c
26                 and max(self.x, c.x) <= min(self.x + self.w, c.x + c.w)
27                 and max(self.y, c.y) <= min(self.y + self.h, c.y + c.h)
28             ):
29                 return True
30         return False
31
32
33 def scene2tensor(xh, yh, scene, size=512):
34     width, height = size, size
35     pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
36     data = pixel_map.numpy()
37     surface = cairo.ImageSurface.create_for_data(
38         data, cairo.FORMAT_ARGB32, width, height
39     )
40
41     ctx = cairo.Context(surface)
42     ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
43
44     for b in scene:
45         ctx.move_to(b.x * size, b.y * size)
46         ctx.rel_line_to(b.w * size, 0)
47         ctx.rel_line_to(0, b.h * size)
48         ctx.rel_line_to(-b.w * size, 0)
49         ctx.close_path()
50         ctx.set_source_rgba(b.r, b.g, b.b, 1.0)
51         ctx.fill_preserve()
52         ctx.set_source_rgba(0, 0, 0, 1.0)
53         ctx.stroke()
54
55     hs = size * 0.05
56     ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
57     ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
58     ctx.rel_line_to(hs, 0)
59     ctx.rel_line_to(0, hs)
60     ctx.rel_line_to(-hs, 0)
61     ctx.close_path()
62     ctx.fill()
63
64     return pixel_map[None, :, :, :3].flip(-1).permute(0, 3, 1, 2).float() / 255
65
66
67 def random_scene():
68     scene = []
69     colors = [
70         (1.00, 0.00, 0.00),
71         (0.00, 1.00, 0.00),
72         (0.00, 0.00, 1.00),
73         (1.00, 1.00, 0.00),
74         (0.75, 0.75, 0.75),
75     ]
76
77     for k in range(10):
78         wh = torch.rand(2) * 0.2 + 0.2
79         xy = torch.rand(2) * (1 - wh)
80         c = colors[torch.randint(len(colors), (1,))]
81         b = Box(
82             xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
83         )
84         if not b.collision(scene):
85             scene.append(b)
86
87     return scene
88
89
90 def sequence(length=10):
91     delta = 0.1
92     effects = [
93         (False, 0, 0),
94         (False, delta, 0),
95         (False, 0, delta),
96         (False, -delta, 0),
97         (False, 0, -delta),
98         (True, delta, 0),
99         (True, 0, delta),
100         (True, -delta, 0),
101         (True, 0, -delta),
102     ]
103
104     while True:
105         scene = random_scene()
106         xh, yh = tuple(x.item() for x in torch.rand(2))
107
108         frame_start = scene2tensor(xh, yh, scene)
109
110         actions = torch.randint(len(effects), (length,))
111         change = False
112
113         for a in actions:
114             g, dx, dy = effects[a]
115             if g:
116                 for b in scene:
117                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
118                         x, y = b.x, b.y
119                         b.x += dx
120                         b.y += dy
121                         if (
122                             b.x < 0
123                             or b.y < 0
124                             or b.x + b.w > 1
125                             or b.y + b.h > 1
126                             or b.collision(scene)
127                         ):
128                             b.x, b.y = x, y
129                         else:
130                             xh += dx
131                             yh += dy
132                             change = True
133             else:
134                 x, y = xh, yh
135                 xh += dx
136                 yh += dy
137                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
138                     xh, yh = x, y
139
140         frame_end = scene2tensor(xh, yh, scene)
141         if change:
142             break
143
144     return frame_start, frame_end, actions
145
146
147 if __name__ == "__main__":
148     frame_start, frame_end, actions = sequence()
149     torchvision.utils.save_image(frame_start, "world_start.png")
150     torchvision.utils.save_image(frame_end, "world_end.png")