Added the copyright comment + reformatted with black.
[path.git] / path.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import sys, math, time, argparse
9
10 import torch, torchvision
11
12 from torch import nn
13 from torch.nn import functional as F
14
15 ######################################################################
16
# Command-line configuration: training schedule, model capacity,
# data-set sizes, and world (maze) geometry.
parser = argparse.ArgumentParser(
    description="Path-planning as denoising.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--batch_size", type=int, default=100)

# Depth / width / receptive field of the ResNet defined below.
parser.add_argument("--nb_residual_blocks", type=int, default=16)

parser.add_argument("--nb_channels", type=int, default=128)

parser.add_argument("--kernel_size", type=int, default=3)

# Number of (maze, trajectory) samples generated for train / test.
parser.add_argument("--nb_for_train", type=int, default=100000)

parser.add_argument("--nb_for_test", type=int, default=10000)

# Maze dimensions in cells.  Wall segments are snapped to even
# coordinates in create_maze, so odd sizes presumably keep a corridor
# lattice — TODO confirm intent.
parser.add_argument("--world_height", type=int, default=23)

parser.add_argument("--world_width", type=int, default=31)

parser.add_argument("--world_nb_walls", type=int, default=15)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

######################################################################

args = parser.parse_args()

# Seeding is skipped for negative seeds, as documented in --seed's help.
if args.seed >= 0:
    torch.manual_seed(args.seed)

######################################################################

# Prefix for the log-file name; empty by default.
label = ""

log_file = open(f"path_{label}train.log", "w")
59 ######################################################################
60
61
def log_string(s):
    """Print *s* to stdout and append it to the module-level log file,
    prefixed with a local timestamp; both streams are flushed at once."""
    stamped = f"{time.strftime('%Y%m%d-%H:%M:%S', time.localtime())} - {s}"
    if log_file is not None:
        log_file.write(stamped + "\n")
        log_file.flush()
    print(stamped)
    sys.stdout.flush()
72
73 ######################################################################
74
75
class ETA:
    """Linear-extrapolation estimate of when a loop of n steps finishes."""

    def __init__(self, n):
        # Total number of steps and the wall-clock time at which they started.
        self.n = n
        self.t0 = time.time()

    def eta(self, k):
        """Return the projected completion time after k finished steps as a
        local-time string, or "n.a." when no step has completed yet."""
        if k <= 0:
            return "n.a."
        now = time.time()
        # Scale the elapsed time up to all n steps (floored, as before).
        finish = self.t0 + ((now - self.t0) * self.n) // k
        return time.strftime("%Y%m%d-%H:%M:%S", time.localtime(finish))
89
90 ######################################################################
91
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log_string(f"device {device}")
96 ######################################################################
97
98
def create_maze(h=11, w=15, nb_walls=10):
    """Generate an h x w maze as an int64 tensor (1 = wall, 0 = free).

    The maze is the outer border plus nb_walls randomly placed straight
    segments snapped to even coordinates.  If placement stalls (more than
    10 * nb_walls rejected attempts) the whole maze is restarted.
    """
    # a counts attempts since the last restart, k counts accepted walls.
    a, k = 0, 0

    while k < nb_walls:
        while True:
            # a == 0 means a fresh start: rebuild the bordered empty grid.
            if a == 0:
                m = torch.zeros(h, w, dtype=torch.int64)
                m[0, :] = 1
                m[-1, :] = 1
                m[:, 0] = 1
                m[:, -1] = 1

            r = torch.rand(4)

            if r[0] <= 0.5:
                # Vertical segment: row span i1..i2 in column j.
                i1, i2, j = (
                    int((r[1] * h).item()),
                    int((r[2] * h).item()),
                    int((r[3] * w).item()),
                )
                # Snap all coordinates to even values.
                i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
                i1, i2 = min(i1, i2), max(i1, i2)
                # Accept only segments of moderate length that overlap at
                # most one existing wall cell.
                if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
                    m[i1 : i2 + 1, j] = 1
                    break
            else:
                # Horizontal segment: column span j1..j2 in row i.
                i, j1, j2 = (
                    int((r[1] * h).item()),
                    int((r[2] * w).item()),
                    int((r[3] * w).item()),
                )
                i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
                j1, j2 = min(j1, j2), max(j1, j2)
                if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
                    m[i, j1 : j2 + 1] = 1
                    break
            a += 1

            # Too many rejections: give up and restart from an empty maze.
            if a > 10 * nb_walls:
                a, k = 0, 0

        k += 1

    return m
143
144
145 ######################################################################
146
147
def random_free_position(walls):
    """Return the (row, col) of a uniformly-chosen cell with walls == 0."""
    order = torch.randperm(walls.numel())
    flat = walls.view(-1)
    # First entry of the random permutation that lands on a free cell.
    idx = order[flat[order] == 0][0].item()
    width = walls.size(1)
    return idx // width, idx % width
152
153
def create_transitions(walls, nb):
    """Simulate nb random-walk steps from a random free cell.

    Returns a (9, h, w) tensor: channels 0-3 hold per-cell frequencies of
    successful moves in directions (0,1), (1,0), (0,-1), (-1,0); channels
    4-7 hold frequencies of blocked attempts in the same directions; and
    channel 8 holds the per-cell total attempt count.
    """
    trans = walls.new_zeros((9,) + walls.size())
    # One random direction per step, drawn up front.
    t = torch.randint(4, (nb,))
    i, j = random_free_position(walls)

    for k in range(t.size(0)):
        # Direction index -> (row delta, col delta).
        di, dj = [(0, 1), (1, 0), (0, -1), (-1, 0)][t[k]]
        ip, jp = i + di, j + dj
        if (
            ip < 0
            or ip >= walls.size(0)
            or jp < 0
            or jp >= walls.size(1)
            or walls[ip, jp] > 0
        ):
            # Move blocked by the grid edge or a wall: channels 4-7.
            trans[t[k] + 4, i, j] += 1
        else:
            trans[t[k], i, j] += 1
            i, j = ip, jp

    # Per-cell total number of attempts (successful or blocked).
    n = trans[0:8].sum(dim=0, keepdim=True)
    trans[8:9] = n
    # NOTE(review): trans inherits walls' integer dtype, so this in-place
    # assignment truncates the ratios to 0/1 — this looks like it was
    # meant to produce float frequencies; confirm intent before changing.
    trans[0:8] = trans[0:8] / (n + (n == 0).long())

    return trans
179
180
181 ######################################################################
182
183
def compute_distance(walls, i, j):
    """Shortest 4-neighbor path length from (i, j) to every free cell.

    Relaxes neighbor distances to a fixed point.  Wall cells are zeroed in
    the returned tensor; unreachable free cells keep the walls.numel()
    sentinel value.
    """
    sentinel = walls.numel()
    dist = torch.full_like(walls, sentinel)
    dist[i, j] = 0
    previous = torch.empty_like(dist)

    while True:
        previous.copy_(dist)

        # Stack the four shifted views so each interior cell sees its
        # left / down / right / up neighbors, then take the minimum + 1.
        neighbors = torch.cat(
            (
                dist[None, 1:-1, 0:-2],
                dist[None, 2:, 1:-1],
                dist[None, 1:-1, 2:],
                dist[None, 0:-2, 1:-1],
            ),
            0,
        )
        relaxed = neighbors.min(dim=0)[0] + 1

        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], relaxed)
        # Pin wall cells at the sentinel so paths never cross them.
        dist = walls * sentinel + (1 - walls) * dist

        if dist.equal(previous):
            # Fixed point reached; zero the walls for the caller.
            return dist * (1 - walls)
212
213 ######################################################################
214
215
def compute_policy(walls, i, j):
    """Greedy policy toward (i, j): a (4, h, w) tensor whose channel d is
    the probability of stepping toward the direction-d neighbor at each
    free cell.  Ties are split uniformly; wall cells are all-zero."""
    distance = compute_distance(walls, i, j)
    # Push wall cells far away so they are never picked as a destination.
    distance = distance + walls.numel() * walls

    # value[d] holds the distance of the neighbor reached by moving in
    # direction d; cells whose neighbor is off-grid keep the sentinel.
    value = distance.new_full((4,) + distance.size(), walls.numel())
    value[0, :, 1:] = distance[:, :-1]
    value[1, :, :-1] = distance[:, 1:]
    value[2, 1:, :] = distance[:-1, :]
    value[3, :-1, :] = distance[1:, :]

    # Uniform distribution over the minimizing directions at each cell.
    best = value.min(dim=0)[0][None]
    proba = (best == value).float()
    proba = proba / proba.sum(dim=0)[None]
    return proba * (1 - walls)
231
232
233 ######################################################################
234
235
def create_maze_data(nb, h=11, w=17, nb_walls=8, traj_length=50):
    """Build nb samples for training/testing.

    Each sample pairs a (10, h, w) input — the 9 transition-statistics
    channels from create_transitions plus a one-hot start position — with
    a (2, h, w) target holding the wall map and the distance-to-start map.

    traj_length is either a fixed number of random-walk steps, or a
    (low, high) tuple from which a per-sample length is drawn uniformly.

    Returns the (input, targets) float tensors.
    """
    input = torch.empty(nb, 10, h, w)
    targets = torch.empty(nb, 2, h, w)

    # isinstance instead of `type(...) == tuple`; `lengths` instead of the
    # ambiguous single-letter `l`.
    if isinstance(traj_length, tuple):
        lengths = (
            torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]
        ).long()
    else:
        lengths = torch.full((nb,), traj_length).long()

    eta = ETA(nb)

    for n in range(nb):
        # Progress report every max(10, nb // 1000) samples.
        if n % (max(10, nb // 1000)) == 0:
            log_string(f"{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}")

        walls = create_maze(h, w, nb_walls)
        trans = create_transitions(walls, lengths[n])

        # Random free start cell, encoded one-hot, plus its distance field.
        i, j = random_free_position(walls)
        start = walls.new_zeros(walls.size())
        start[i, j] = 1
        dist = compute_distance(walls, i, j)

        input[n] = torch.cat((trans, start[None]), 0)
        targets[n] = torch.cat((walls[None], dist[None]), 0)

    return input, targets
263
264
265 ######################################################################
266
267
def save_image(name, input, targets, output=None):
    """Write a PNG grid visualizing a batch: one grid row per sample, with
    columns for the start position, visit counts, true walls, true
    distance, and — when output is given — predicted walls and distance.

    input:   (N, 10, h, w) transition statistics + one-hot start channel
    targets: (N, 2, h, w) wall map and distance map
    output:  optional (N, 2, h, w) model prediction
    """
    input, targets = input.cpu(), targets.cpu()

    # Direction-to-RGB mixing matrix for the (disabled) policy rendering.
    weight = torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
    ).t()[:, :, None, None]

    # img_trans = F.conv2d(input[:, 0:5], weight)
    # img_trans = img_trans / img_trans.max()

    # Reciprocal of channel 8 (visit counts), inverted for display.
    # NOTE(review): a zero count gives 1/0 = inf here, which turns into
    # NaN after the normalization below — confirm whether intended.
    img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
    img_trans = 1 - img_trans / img_trans.max()

    # One-hot start position, inverted (start cell dark).
    img_start = input[:, 9:10].expand(-1, 3, -1, -1)
    img_start = 1 - img_start / img_start.max()

    # Ground-truth walls, inverted (walls dark).
    img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
    img_walls = 1 - img_walls / img_walls.max()

    # img_pi = F.conv2d(targets[:, 2:6], weight)
    # img_pi = img_pi / img_pi.max()

    # Ground-truth distance map, scaled to [0, 1].
    img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
    img_dist = img_dist / img_dist.max()

    # One entry per visualization column; the size-1 dim added by [:, None]
    # is the column axis used for the concatenation below.
    img = (
        img_start[:, None],
        img_trans[:, None],
        img_walls[:, None],
        # img_pi[:, None],
        img_dist[:, None],
    )

    if output is not None:
        output = output.cpu()
        # Predicted walls and distance, rendered like their targets.
        img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
        img_walls = 1 - img_walls / img_walls.max()

        # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
        # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]

        img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
        img_dist = img_dist / img_dist.max()

        img += (
            img_walls[:, None],
            img_dist[:, None],
            # img_pi[:, None],
        )

    # (N, nb_columns, 3, h, w) -> (N * nb_columns, 3, h, w) so that, with
    # nrow=len(img), each sample occupies one row of the saved grid.
    img_all = torch.cat(img, 1)

    img_all = img_all.view(
        img_all.size(0) * img_all.size(1),
        img_all.size(2),
        img_all.size(3),
        img_all.size(4),
    )

    torchvision.utils.save_image(img_all, name, padding=1, pad_value=0.5, nrow=len(img))

    log_string(f"Wrote {name}")
336
337 ######################################################################
338
339
class Net(nn.Module):
    """Plain four-layer CNN baseline: 6 input channels -> 2 output maps,
    5x5 convolutions with padding 2 so the spatial size is preserved."""

    def __init__(self):
        super().__init__()
        nh = 128  # hidden width
        self.conv1 = nn.Conv2d(6, nh, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(nh, 2, kernel_size=5, padding=2)

    def forward(self, x):
        # ReLU after every layer except the linear read-out.
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu(layer(x))
        return self.conv4(x)
355
356
357 ######################################################################
358
359
class ResNetBlock(nn.Module):
    """Two-convolution residual block with batch norm:
    y = relu(x + bn2(conv2(relu(bn1(conv1(x)))))); spatial size preserved."""

    def __init__(self, nb_channels, kernel_size):
        super().__init__()

        pad = (kernel_size - 1) // 2  # "same" padding for odd kernels

        self.conv1 = nn.Conv2d(
            nb_channels, nb_channels, kernel_size=kernel_size, padding=pad
        )
        self.bn1 = nn.BatchNorm2d(nb_channels)
        self.conv2 = nn.Conv2d(
            nb_channels, nb_channels, kernel_size=kernel_size, padding=pad
        )
        self.bn2 = nn.BatchNorm2d(nb_channels)

    def forward(self, x):
        inner = F.relu(self.bn1(self.conv1(x)))
        return F.relu(x + self.bn2(self.conv2(inner)))
386
387
class ResNet(nn.Module):
    """Convolutional stem, a stack of residual blocks, and a 1x1
    per-pixel read-out; spatial dimensions are preserved throughout."""

    def __init__(
        self, in_channels, out_channels, nb_residual_blocks, nb_channels, kernel_size
    ):
        super().__init__()

        # Lift the input to nb_channels feature maps.
        self.pre_process = nn.Sequential(
            nn.Conv2d(
                in_channels,
                nb_channels,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
            nn.BatchNorm2d(nb_channels),
            nn.ReLU(inplace=True),
        )

        self.resnet_blocks = nn.Sequential(
            *(ResNetBlock(nb_channels, kernel_size) for _ in range(nb_residual_blocks))
        )

        # Per-pixel linear projection to the output channels.
        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.post_process(self.resnet_blocks(self.pre_process(x)))
418
419
420 ######################################################################
421
data_filename = "path.dat"

# Load the cached data set if one exists and matches the current
# configuration; otherwise generate it from scratch and cache it.
try:
    input, targets = torch.load(data_filename)
    log_string("Data loaded.")
    # Reject caches generated with different command-line settings.
    assert (
        input.size(0) == args.nb_for_train + args.nb_for_test
        and input.size(1) == 10
        and input.size(2) == args.world_height
        and input.size(3) == args.world_width
        and targets.size(0) == args.nb_for_train + args.nb_for_test
        and targets.size(1) == 2
        and targets.size(2) == args.world_height
        and targets.size(3) == args.world_width
    )

except FileNotFoundError:
    log_string("Generating data.")

    input, targets = create_maze_data(
        nb=args.nb_for_train + args.nb_for_test,
        h=args.world_height,
        w=args.world_width,
        nb_walls=args.world_nb_walls,
        traj_length=(100, 10000),
    )

    torch.save((input, targets), data_filename)

except Exception as e:
    # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
    # still propagate, and the actual failure (e.g. the shape check
    # above) is reported instead of being silently discarded.
    log_string("Error when loading data.")
    log_string(f"{type(e).__name__}: {e}")
    exit(1)
454
455 ######################################################################
456
# Log the full configuration for reproducibility.
for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")

# 10 input channels: 8 move-frequency channels + visit counts + one-hot
# start; 2 output channels: predicted wall map and distance map.
model = ResNet(
    in_channels=10,
    out_channels=2,
    nb_residual_blocks=args.nb_residual_blocks,
    nb_channels=args.nb_channels,
    kernel_size=args.kernel_size,
)

criterion = nn.MSELoss()

model.to(device)
criterion.to(device)

input, targets = input.to(device), targets.to(device)

# First nb_for_train samples train, the remainder test.
train_input, train_targets = input[: args.nb_for_train], targets[: args.nb_for_train]
test_input, test_targets = input[args.nb_for_train :], targets[args.nb_for_train :]

# Standardize inputs in place, using training statistics for both splits.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
481
482 ######################################################################
483
eta = ETA(args.nb_epochs)

for e in range(args.nb_epochs):
    # Two-step learning-rate schedule: 1e-2 then 1e-3.
    if e < args.nb_epochs // 2:
        lr = 1e-2
    else:
        lr = 1e-3

    # NOTE(review): re-creating the optimizer every epoch resets Adam's
    # moment estimates; kept as-is since results may depend on it.  To
    # preserve state across epochs, create it once and update
    # param_groups[0]["lr"] instead.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    acc_train_loss = 0.0

    for input, targets in zip(
        train_input.split(args.batch_size), train_targets.split(args.batch_size)
    ):
        output = model(input)

        loss = criterion(output, targets)
        acc_train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation with batch norm in eval mode and autograd disabled
    # (previously the test loss was computed in train mode with
    # gradients enabled, skewing the metric and wasting memory).
    model.eval()
    test_loss = 0.0

    with torch.no_grad():
        for input, targets in zip(
            test_input.split(args.batch_size), test_targets.split(args.batch_size)
        ):
            output = model(input)
            loss = criterion(output, targets)
            test_loss += loss.item()

    log_string(
        f"{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}"
    )

    # Visualize a few train and test samples with the current model.
    with torch.no_grad():
        save_image(
            f"train_{e:04d}.png",
            train_input[:8],
            train_targets[:8],
            model(train_input[:8])[:, 0:2],
        )
        save_image(
            f"test_{e:04d}.png",
            test_input[:8],
            test_targets[:8],
            model(test_input[:8])[:, 0:2],
        )
537
538 ######################################################################