3 import sys, math, time, argparse
5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
12 parser = argparse.ArgumentParser(
13 description='Path-planning as denoising.',
14 formatter_class = argparse.ArgumentDefaultsHelpFormatter
17 parser.add_argument('--nb_epochs',
18 type = int, default = 25)
20 parser.add_argument('--batch_size',
21 type = int, default = 100)
23 parser.add_argument('--nb_residual_blocks',
24 type = int, default = 16)
26 parser.add_argument('--nb_channels',
27 type = int, default = 128)
29 parser.add_argument('--kernel_size',
30 type = int, default = 3)
32 parser.add_argument('--nb_for_train',
33 type = int, default = 100000)
35 parser.add_argument('--nb_for_test',
36 type = int, default = 10000)
38 parser.add_argument('--world_height',
39 type = int, default = 23)
41 parser.add_argument('--world_width',
42 type = int, default = 31)
44 parser.add_argument('--world_nb_walls',
45 type = int, default = 15)
47 parser.add_argument('--seed',
48 type = int, default = 0,
49 help = 'Random seed (default 0, < 0 is no seeding)')
51 ######################################################################
53 args = parser.parse_args()
56 torch.manual_seed(args.seed)
58 ######################################################################
62 log_file = open(f'path_{label}train.log', 'w')
64 ######################################################################
67 t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
69 if log_file is not None:
70 log_file.write(s + '\n')
76 ######################################################################
79 def __init__(self, n):
86 u = self.t0 + ((t - self.t0) * self.n) // k
87 return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
91 ######################################################################
93 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95 log_string(f'device {device}')
97 ######################################################################
99 def create_maze(h = 11, w = 15, nb_walls = 10):
105 m = torch.zeros(h, w, dtype = torch.int64)
114 i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
115 i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
116 i1, i2 = min(i1, i2), max(i1, i2)
117 if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
121 i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
122 i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
123 j1, j2 = min(j1, j2), max(j1, j2)
124 if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
129 if a > 10 * nb_walls: a, k = 0, 0
135 ######################################################################
def random_free_position(walls):
    """Return a uniformly random free cell of a 2D ``walls`` grid.

    A cell is free when its entry equals 0. The flat (row-major) index of
    the chosen cell is converted back to ``(row, column)`` using
    ``walls.size(1)`` as the row length.

    Args:
        walls: 2D tensor; non-zero entries are walls, zeros are free cells.
            Non-contiguous tensors (e.g. transposed views) are accepted.

    Returns:
        ``(i, j)`` as a tuple of Python ints.

    Raises:
        IndexError: if ``walls`` contains no free cell.
    """
    # Keep exactly one randperm over all cells (as before), so seeded runs
    # draw the same random numbers and therefore pick the same positions.
    perm = torch.randperm(walls.numel())
    # reshape, unlike view, also works when walls is not contiguous.
    free = perm[walls.reshape(-1)[perm] == 0]
    if free.numel() == 0:
        raise IndexError('random_free_position: no free cell (every entry is a wall)')
    k = free[0].item()
    n_cols = walls.size(1)
    return k // n_cols, k % n_cols
142 def create_transitions(walls, nb):
143 trans = walls.new_zeros((9,) + walls.size())
144 t = torch.randint(4, (nb,))
145 i, j = random_free_position(walls)
147 for k in range(t.size(0)):
148 di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
149 ip, jp = i + di, j + dj
150 if ip < 0 or ip >= walls.size(0) or \
151 jp < 0 or jp >= walls.size(1) or \
153 trans[t[k] + 4, i, j] += 1
155 trans[t[k], i, j] += 1
158 n = trans[0:8].sum(dim = 0, keepdim = True)
160 trans[0:8] = trans[0:8] / (n + (n == 0).long())
164 ######################################################################
166 def compute_distance(walls, i, j):
167 max_length = walls.numel()
168 dist = torch.full_like(walls, max_length)
171 pred_dist = torch.empty_like(dist)
174 pred_dist.copy_(dist)
177 dist[None, 1:-1, 0:-2],
178 dist[None, 2:, 1:-1],
179 dist[None, 1:-1, 2:],
180 dist[None, 0:-2, 1:-1]
182 0).min(dim = 0)[0] + 1
184 dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
185 dist = walls * max_length + (1 - walls) * dist
187 if dist.equal(pred_dist): return dist * (1 - walls)
189 ######################################################################
191 def compute_policy(walls, i, j):
192 distance = compute_distance(walls, i, j)
193 distance = distance + walls.numel() * walls
195 value = distance.new_full((4,) + distance.size(), walls.numel())
196 value[0, : , 1: ] = distance[ : , :-1]
197 value[1, : , :-1] = distance[ : , 1: ]
198 value[2, 1: , : ] = distance[ :-1, : ]
199 value[3, :-1, : ] = distance[1: , : ]
201 proba = (value.min(dim = 0)[0][None] == value).float()
202 proba = proba / proba.sum(dim = 0)[None]
203 proba = proba * (1 - walls)
207 ######################################################################
209 def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
210 input = torch.empty(nb, 10, h, w)
211 targets = torch.empty(nb, 2, h, w)
213 if type(traj_length) == tuple:
214 l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long()
216 l = torch.full((nb,), traj_length).long()
221 if n%(max(10, nb//1000)) == 0:
222 log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')
224 walls = create_maze(h, w, nb_walls)
225 trans = create_transitions(walls, l[n])
227 i, j = random_free_position(walls)
228 start = walls.new_zeros(walls.size())
230 dist = compute_distance(walls, i, j)
232 input[n] = torch.cat((trans, start[None]), 0)
233 targets[n] = torch.cat((walls[None], dist[None]), 0)
235 return input, targets
237 ######################################################################
239 def save_image(name, input, targets, output = None):
240 input, targets = input.cpu(), targets.cpu()
242 weight = torch.tensor(
248 ] ).t()[:, :, None, None]
250 # img_trans = F.conv2d(input[:, 0:5], weight)
251 # img_trans = img_trans / img_trans.max()
253 img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
254 img_trans = 1 - img_trans / img_trans.max()
256 img_start = input[:, 9:10].expand(-1, 3, -1, -1)
257 img_start = 1 - img_start / img_start.max()
259 img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
260 img_walls = 1 - img_walls / img_walls.max()
262 # img_pi = F.conv2d(targets[:, 2:6], weight)
263 # img_pi = img_pi / img_pi.max()
265 img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
266 img_dist = img_dist / img_dist.max()
276 if output is not None:
277 output = output.cpu()
278 img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
279 img_walls = 1 - img_walls / img_walls.max()
281 # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
282 # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]
284 img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
285 img_dist = img_dist / img_dist.max()
293 img_all = torch.cat(img, 1)
295 img_all = img_all.view(
296 img_all.size(0) * img_all.size(1),
302 torchvision.utils.save_image(
305 padding = 1, pad_value = 0.5, nrow = len(img)
308 log_string(f'Wrote {name}')
310 ######################################################################
312 class Net(nn.Module):
316 self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
317 self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
318 self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
319 self.conv4 = nn.Conv2d(nh, 2, kernel_size = 5, padding = 2)
321 def forward(self, x):
322 x = F.relu(self.conv1(x))
323 x = F.relu(self.conv2(x))
324 x = F.relu(self.conv3(x))
328 ######################################################################
330 class ResNetBlock(nn.Module):
331 def __init__(self, nb_channels, kernel_size):
334 self.conv1 = nn.Conv2d(nb_channels, nb_channels,
335 kernel_size = kernel_size,
336 padding = (kernel_size - 1) // 2)
338 self.bn1 = nn.BatchNorm2d(nb_channels)
340 self.conv2 = nn.Conv2d(nb_channels, nb_channels,
341 kernel_size = kernel_size,
342 padding = (kernel_size - 1) // 2)
344 self.bn2 = nn.BatchNorm2d(nb_channels)
346 def forward(self, x):
347 y = F.relu(self.bn1(self.conv1(x)))
348 y = F.relu(x + self.bn2(self.conv2(y)))
351 class ResNet(nn.Module):
354 in_channels, out_channels,
355 nb_residual_blocks, nb_channels, kernel_size):
358 self.pre_process = nn.Sequential(
359 nn.Conv2d(in_channels, nb_channels,
360 kernel_size = kernel_size,
361 padding = (kernel_size - 1) // 2),
362 nn.BatchNorm2d(nb_channels),
363 nn.ReLU(inplace = True),
367 for k in range(nb_residual_blocks):
368 blocks.append(ResNetBlock(nb_channels, kernel_size))
370 self.resnet_blocks = nn.Sequential(*blocks)
372 self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)
374 def forward(self, x):
375 x = self.pre_process(x)
376 x = self.resnet_blocks(x)
377 x = self.post_process(x)
380 ######################################################################
382 data_filename = 'path.dat'
385 input, targets = torch.load(data_filename)
386 log_string('Data loaded.')
387 assert input.size(0) == args.nb_for_train + args.nb_for_test and \
388 input.size(1) == 10 and \
389 input.size(2) == args.world_height and \
390 input.size(3) == args.world_width and \
392 targets.size(0) == args.nb_for_train + args.nb_for_test and \
393 targets.size(1) == 2 and \
394 targets.size(2) == args.world_height and \
395 targets.size(3) == args.world_width
397 except FileNotFoundError:
398 log_string('Generating data.')
400 input, targets = create_maze_data(
401 nb = args.nb_for_train + args.nb_for_test,
402 h = args.world_height, w = args.world_width,
403 nb_walls = args.world_nb_walls,
404 traj_length = (100, 10000)
407 torch.save((input, targets), data_filename)
410 log_string('Error when loading data.')
413 ######################################################################
416 log_string(f'args.{n} {getattr(args, n)}')
419 in_channels = 10, out_channels = 2,
420 nb_residual_blocks = args.nb_residual_blocks,
421 nb_channels = args.nb_channels,
422 kernel_size = args.kernel_size
425 criterion = nn.MSELoss()
430 input, targets = input.to(device), targets.to(device)
432 train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
433 test_input, test_targets = input[args.nb_for_train:], targets[args.nb_for_train:]
435 mu, std = train_input.mean(), train_input.std()
436 train_input.sub_(mu).div_(std)
437 test_input.sub_(mu).div_(std)
439 ######################################################################
441 eta = ETA(args.nb_epochs)
443 for e in range(args.nb_epochs):
445 if e < args.nb_epochs // 2:
450 optimizer = torch.optim.Adam(model.parameters(), lr = lr)
454 for input, targets in zip(train_input.split(args.batch_size),
455 train_targets.split(args.batch_size)):
456 output = model(input)
458 loss = criterion(output, targets)
459 acc_train_loss += loss.item()
461 optimizer.zero_grad()
467 for input, targets in zip(test_input.split(args.batch_size),
468 test_targets.split(args.batch_size)):
469 output = model(input)
470 loss = criterion(output, targets)
471 test_loss += loss.item()
474 f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
477 # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
478 # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
480 save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
481 save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])
483 ######################################################################