3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import sys, math, time, argparse
10 import torch, torchvision
13 from torch.nn import functional as F
15 ######################################################################
# Command-line interface.
# NOTE(review): this span is an incomplete excerpt -- the closing
# parenthesis of the ArgumentParser(...) call and the
# parser.add_argument( wrapper around the "--seed" option fall on
# lines not shown here.
parser = argparse.ArgumentParser(
    description="Path-planning as denoising.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
# Optimization schedule.
parser.add_argument("--nb_epochs", type=int, default=25)
parser.add_argument("--batch_size", type=int, default=100)
# Capacity of the ResNet used as the denoiser.
parser.add_argument("--nb_residual_blocks", type=int, default=16)
parser.add_argument("--nb_channels", type=int, default=128)
parser.add_argument("--kernel_size", type=int, default=3)
# Number of train / test samples to generate (or load from cache).
parser.add_argument("--nb_for_train", type=int, default=100000)
parser.add_argument("--nb_for_test", type=int, default=10000)
# Maze geometry: grid size and number of wall segments dropped.
parser.add_argument("--world_height", type=int, default=23)
parser.add_argument("--world_width", type=int, default=31)
parser.add_argument("--world_nb_walls", type=int, default=15)
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
######################################################################
args = parser.parse_args()
# NOTE(review): the help text promises "< 0 is no seeding"; the
# guarding if-statement presumably sits on a line not shown here.
torch.manual_seed(args.seed)
######################################################################
# Run log; `label` is defined on lines outside this excerpt --
# presumably empty or a tag derived from the arguments.
log_file = open(f"path_{label}train.log", "w")
59 ######################################################################
    # Body fragment of the logging helper (its `def` line is outside the
    # visible excerpt): build a timestamp and append the message to the
    # log file when one is open.
    t = time.strftime("%Y%m%d-%H:%M:%S", time.localtime())
    if log_file is not None:
        log_file.write(s + "\n")
73 ######################################################################
    # Fragment of the ETA helper class (class header, __init__ body and
    # the eta() method header are outside the visible excerpt).
    def __init__(self, n):

    # Linear extrapolation of the completion time after k of self.n
    # steps from the start time self.t0; `t` (current time) is set on a
    # line not shown.  Returns a formatted wall-clock timestamp.
        u = self.t0 + ((t - self.t0) * self.n) // k
        return time.strftime("%Y%m%d-%H:%M:%S", time.localtime(u))
90 ######################################################################
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log_string(f"device {device}")
96 ######################################################################
def create_maze(h=11, w=15, nb_walls=10):
    # Build an h x w {0, 1} wall map by repeatedly dropping random
    # vertical / horizontal wall segments snapped to even coordinates.
    # NOTE(review): incomplete excerpt -- the sampling of the random
    # vector `r`, the attempt counter `a`, the surrounding loop and the
    # final return are on lines not shown.
    m = torch.zeros(h, w, dtype=torch.int64)
    # Candidate endpoints for a vertical wall segment, drawn from `r`
    # (defined on lines not shown).
        int((r[1] * h).item()),
        int((r[2] * h).item()),
        int((r[3] * w).item()),
    # Snap everything to even coordinates so corridors (odd cells) stay
    # one cell wide.
    i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
    i1, i2 = min(i1, i2), max(i1, i2)
    # Accept only walls that are non-trivial, at most half the grid
    # long, and barely overlap existing walls.
    if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
        m[i1 : i2 + 1, j] = 1
    # Same construction for a horizontal wall segment.
        int((r[1] * h).item()),
        int((r[2] * w).item()),
        int((r[3] * w).item()),
    i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
    j1, j2 = min(j1, j2), max(j1, j2)
    if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
        m[i, j1 : j2 + 1] = 1
    # Give up after too many attempts; `a` is incremented on a line not
    # shown.
    if a > 10 * nb_walls:
145 ######################################################################
def random_free_position(walls):
    """Pick a uniformly random free cell of the maze.

    `walls` is a 2D tensor where non-zero marks a wall.  Returns a
    `(row, col)` tuple of Python ints for a cell with value 0.  Raises
    IndexError if no free cell exists.
    """
    flat = walls.reshape(-1)
    # Random visiting order over all flattened cell indices.
    order = torch.randperm(flat.numel())
    # Keep only the indices of free cells, take the first one.
    free = order[flat[order] == 0]
    idx = free[0].item()
    # Convert the flat index back to 2D coordinates.
    return divmod(idx, walls.size(1))
def create_transitions(walls, nb):
    # Simulate `nb` random moves of an agent in the maze and accumulate
    # per-cell transition statistics in a (9, H, W) tensor.  Channels
    # 0-3 hold normalized per-direction move frequencies; channels 4-7
    # are incremented when a move is rejected; channel 8 is presumably
    # a visit count -- confirm against the lines not shown.
    # NOTE(review): incomplete excerpt -- the start of the
    # out-of-bounds/wall test, its if/else structure, the position
    # update and the final return are on lines not shown.
    trans = walls.new_zeros((9,) + walls.size())
    t = torch.randint(4, (nb,))  # one random direction per step
    i, j = random_free_position(walls)

    for k in range(t.size(0)):
        # Candidate next position one step in direction t[k].
        di, dj = [(0, 1), (1, 0), (0, -1), (-1, 0)][t[k]]
        ip, jp = i + di, j + dj
        # Remainder of the boundary test (its opening is not shown).
            or ip >= walls.size(0)
            or jp >= walls.size(1)
            trans[t[k] + 4, i, j] += 1
            trans[t[k], i, j] += 1

    # Normalize the first eight channels by their per-cell totals,
    # guarding against division by zero at never-visited cells.
    n = trans[0:8].sum(dim=0, keepdim=True)
    trans[0:8] = trans[0:8] / (n + (n == 0).long())
181 ######################################################################
def compute_distance(walls, i, j):
    # Shortest-path length from every free cell to (i, j) on the
    # 4-connected grid, by iterated relaxation to a fixed point.  Wall
    # cells are returned as 0.
    # NOTE(review): incomplete excerpt -- the seeding of dist at (i, j),
    # the min-plus-one combination of the four shifted maps into `d`,
    # and the surrounding loop header are on lines not shown.
    max_length = walls.numel()  # safe upper bound on any path length
    dist = torch.full_like(walls, max_length)

    pred_dist = torch.empty_like(dist)

    pred_dist.copy_(dist)  # snapshot to detect convergence

    # The four one-cell shifts of the distance map over the interior
    # (presumably stacked/combined on lines not shown).
        dist[None, 1:-1, 0:-2],
        dist[None, 2:, 1:-1],
        dist[None, 1:-1, 2:],
        dist[None, 0:-2, 1:-1],

    dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
    # Pin wall cells at max_length so they never propagate a shorter
    # path to their neighbors.
    dist = walls * max_length + (1 - walls) * dist

    # Fixed point reached: zero out the walls and return.
    if dist.equal(pred_dist):
        return dist * (1 - walls)
213 ######################################################################
def compute_policy(walls, i, j):
    # Greedy policy toward (i, j): for each free cell, spread uniform
    # probability over the neighboring cell(s) of minimal
    # distance-to-goal.
    # NOTE(review): incomplete excerpt -- a couple of lines, and
    # seemingly the final `return proba`, are not shown.
    distance = compute_distance(walls, i, j)
    # Push wall cells far away so they are never the arg-min neighbor.
    distance = distance + walls.numel() * walls

    # value[a] holds, for every cell, the distance of the cell reached
    # by action a; the border rows/columns keep the large fill value so
    # off-grid moves are never preferred.
    value = distance.new_full((4,) + distance.size(), walls.numel())
    value[0, :, 1:] = distance[:, :-1]
    value[1, :, :-1] = distance[:, 1:]
    value[2, 1:, :] = distance[:-1, :]
    value[3, :-1, :] = distance[1:, :]

    # Indicator of the minimizing action(s), normalized to a
    # distribution per cell, then zeroed on walls.
    proba = (value.min(dim=0)[0][None] == value).float()
    proba = proba / proba.sum(dim=0)[None]
    proba = proba * (1 - walls)
233 ######################################################################
def create_maze_data(nb, h=11, w=17, nb_walls=8, traj_length=50):
    # Generate `nb` training pairs: `input[n]` stacks the 9 transition
    # channels with the start-position map (10 channels total);
    # `targets[n]` stacks the wall map and the distance-to-start map.
    # `traj_length` is either an int or a (low, high) range from which
    # per-sample lengths are drawn.
    # NOTE(review): incomplete excerpt -- the `else:` of the
    # traj_length test, the loop `for n in range(nb):`, the ETA setup
    # and the marking of the start cell are on lines not shown.
    input = torch.empty(nb, 10, h, w)
    targets = torch.empty(nb, 2, h, w)

    if type(traj_length) == tuple:
        l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long()
        # (this branch belongs to an `else:` on a line not shown)
        l = torch.full((nb,), traj_length).long()

    # Progress report roughly every 0.1% of the samples.
    if n % (max(10, nb // 1000)) == 0:
        log_string(f"{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}")

    # One sample: maze, simulated transitions, random goal/start cell.
    walls = create_maze(h, w, nb_walls)
    trans = create_transitions(walls, l[n])

    i, j = random_free_position(walls)
    start = walls.new_zeros(walls.size())

    dist = compute_distance(walls, i, j)

    input[n] = torch.cat((trans, start[None]), 0)
    targets[n] = torch.cat((walls[None], dist[None]), 0)

    return input, targets
265 ######################################################################
def save_image(name, input, targets, output=None):
    # Render a batch as a grid of gray-scale panels (visit statistics,
    # start map, walls, distance), optionally adding the model's
    # prediction panels, and write it to `name`.
    # NOTE(review): incomplete excerpt -- the contents of the `weight`
    # tensor literal, the assembly of the `img` list and the reshaping
    # arguments are on lines not shown.
    input, targets = input.cpu(), targets.cpu()

    # Direction-to-color kernel (values not visible in this excerpt).
    weight = torch.tensor(
    ).t()[:, :, None, None]

    # img_trans = F.conv2d(input[:, 0:5], weight)
    # img_trans = img_trans / img_trans.max()

    # Channel 8 of the input visualized as an intensity map.
    # NOTE(review): the reciprocal yields inf where the channel is 0 --
    # presumably harmless for display; confirm.
    img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
    img_trans = 1 - img_trans / img_trans.max()

    img_start = input[:, 9:10].expand(-1, 3, -1, -1)
    img_start = 1 - img_start / img_start.max()

    img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
    img_walls = 1 - img_walls / img_walls.max()

    # img_pi = F.conv2d(targets[:, 2:6], weight)
    # img_pi = img_pi / img_pi.max()

    img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
    img_dist = img_dist / img_dist.max()

    # When a model output is given, rebuild the wall/distance panels
    # from the prediction instead of the target.
    if output is not None:
        output = output.cpu()
        img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
        img_walls = 1 - img_walls / img_walls.max()

        # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
        # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]

        img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
        img_dist = img_dist / img_dist.max()

    # Concatenate the panels and flatten (batch, panel) into the grid
    # rows expected by torchvision.
    img_all = torch.cat(img, 1)

    img_all = img_all.view(
        img_all.size(0) * img_all.size(1),

    torchvision.utils.save_image(img_all, name, padding=1, pad_value=0.5, nrow=len(img))

    log_string(f"Wrote {name}")
337 ######################################################################
class Net(nn.Module):
    # Simple 4-layer convolutional baseline (apparently unused by the
    # training code below, which builds a ResNet).
    # NOTE(review): incomplete excerpt -- the __init__ header defining
    # the hidden width `nh`, the super().__init__() call, and the
    # application of conv4 / return in forward() are on lines not
    # shown.  Also: conv1 expects 6 input channels while the dataset
    # inputs have 10 -- looks stale; confirm before using this class.
        self.conv1 = nn.Conv2d(6, nh, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(nh, 2, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
357 ######################################################################
class ResNetBlock(nn.Module):
    # Standard residual block: two conv+batchnorm stages with a skip
    # connection, channel count preserved.
    # NOTE(review): incomplete excerpt -- the super().__init__() call,
    # the in/out channel arguments of both Conv2d calls and the
    # `return y` of forward() are on lines not shown.
    def __init__(self, nb_channels, kernel_size):

        self.conv1 = nn.Conv2d(
            kernel_size=kernel_size,
            # "same" padding for odd kernel sizes
            padding=(kernel_size - 1) // 2,

        self.bn1 = nn.BatchNorm2d(nb_channels)

        self.conv2 = nn.Conv2d(
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,

        self.bn2 = nn.BatchNorm2d(nb_channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        # Residual addition before the final non-linearity.
        y = F.relu(x + self.bn2(self.conv2(y)))
class ResNet(nn.Module):
    # Denoising backbone: a 1-conv pre-process stage, a stack of
    # residual blocks, and a 1x1 conv mapping to the output channels.
    # NOTE(review): incomplete excerpt -- the __init__ header, the
    # super().__init__() call, the first Conv2d of pre_process, the
    # `blocks = []` initialization and the `return x` of forward() are
    # on lines not shown.
        self, in_channels, out_channels, nb_residual_blocks, nb_channels, kernel_size

        self.pre_process = nn.Sequential(
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            nn.BatchNorm2d(nb_channels),
            nn.ReLU(inplace=True),

        # One residual block per requested depth step.
        for k in range(nb_residual_blocks):
            blocks.append(ResNetBlock(nb_channels, kernel_size))

        self.resnet_blocks = nn.Sequential(*blocks)

        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.pre_process(x)
        x = self.resnet_blocks(x)
        x = self.post_process(x)
420 ######################################################################
# Load the cached dataset when present, validate its shape against the
# current arguments, and regenerate + save it otherwise.
# NOTE(review): incomplete excerpt -- the `try:`, the assert wrapping
# the size checks, the h/w keyword arguments of create_maze_data and
# the except-clause around the last log line are not shown.
data_filename = "path.dat"

    input, targets = torch.load(data_filename)
    log_string("Data loaded.")

    # Cached tensors must match the requested sizes exactly.
    input.size(0) == args.nb_for_train + args.nb_for_test
    and input.size(1) == 10
    and input.size(2) == args.world_height
    and input.size(3) == args.world_width
    and targets.size(0) == args.nb_for_train + args.nb_for_test
    and targets.size(1) == 2
    and targets.size(2) == args.world_height
    and targets.size(3) == args.world_width

except FileNotFoundError:
    log_string("Generating data.")

    input, targets = create_maze_data(
        nb=args.nb_for_train + args.nb_for_test,
        nb_walls=args.world_nb_walls,
        # random per-sample trajectory lengths in [100, 10000)
        traj_length=(100, 10000),

    torch.save((input, targets), data_filename)

    # Presumably reached when the cached file exists but fails the
    # shape validation -- confirm against the lines not shown.
    log_string("Error when loading data.")
455 ######################################################################
# Log every hyper-parameter value (the loop header over the argument
# names is on a line not shown).
    log_string(f"args.{n} {getattr(args, n)}")

# Build the denoiser.  NOTE(review): the `model = ResNet(` call and
# its in/out channel arguments are on lines not shown.
    nb_residual_blocks=args.nb_residual_blocks,
    nb_channels=args.nb_channels,
    kernel_size=args.kernel_size,

# Plain regression loss on the (walls, distance) target maps.
criterion = nn.MSELoss()

input, targets = input.to(device), targets.to(device)

train_input, train_targets = input[: args.nb_for_train], targets[: args.nb_for_train]
test_input, test_targets = input[args.nb_for_train :], targets[args.nb_for_train :]

# Normalize in place with statistics computed on the train split only,
# and apply the same affine transform to the test split.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
482 ######################################################################
# Main training loop.
# NOTE(review): incomplete excerpt -- the learning-rate values of the
# two-phase schedule, the loss.backward()/optimizer.step() calls, the
# accumulator resets and the closing parentheses of the zip()/
# log_string()/save_image() calls are on lines not shown.  Also note a
# fresh Adam optimizer appears to be built every epoch, which resets
# its moment estimates -- confirm whether that is intended.
eta = ETA(args.nb_epochs)

for e in range(args.nb_epochs):
    # Two-phase learning-rate schedule: first half vs second half.
    if e < args.nb_epochs // 2:

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training pass over mini-batches.
    for input, targets in zip(
        train_input.split(args.batch_size), train_targets.split(args.batch_size)
        output = model(input)

        loss = criterion(output, targets)
        acc_train_loss += loss.item()

        optimizer.zero_grad()

    # Evaluation pass (whether gradients are disabled here is not
    # visible in this excerpt).
    for input, targets in zip(
        test_input.split(args.batch_size), test_targets.split(args.batch_size)
        output = model(input)
        loss = criterion(output, targets)
        test_loss += loss.item()

    # Per-epoch averaged losses and estimated completion time.
    f"{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}"

    # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
    # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))

    # Dump visualization grids for the first 8 train and test samples.
    f"train_{e:04d}.png",
    model(train_input[:8])[:, 0:2],

    model(test_input[:8])[:, 0:2],
538 ######################################################################