--- /dev/null
+#!/usr/bin/env python
+
+import sys, math, time, argparse
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description='Path-planning as denoising.',
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--nb_epochs',
+ type = int, default = 25)
+
+parser.add_argument('--batch_size',
+ type = int, default = 100)
+
+parser.add_argument('--nb_residual_blocks',
+ type = int, default = 16)
+
+parser.add_argument('--nb_channels',
+ type = int, default = 128)
+
+parser.add_argument('--kernel_size',
+ type = int, default = 3)
+
+parser.add_argument('--nb_for_train',
+ type = int, default = 100000)
+
+parser.add_argument('--nb_for_test',
+ type = int, default = 10000)
+
+parser.add_argument('--world_height',
+ type = int, default = 23)
+
+parser.add_argument('--world_width',
+ type = int, default = 31)
+
+parser.add_argument('--world_nb_walls',
+ type = int, default = 15)
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
+######################################################################
+
+label=''
+
+log_file = open(f'path_{label}train.log', 'w')
+
+######################################################################
+
+def log_string(s):
+ t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
+ s = t + ' - ' + s
+ if log_file is not None:
+ log_file.write(s + '\n')
+ log_file.flush()
+
+ print(s)
+ sys.stdout.flush()
+
+######################################################################
+
+class ETA:
+ def __init__(self, n):
+ self.n = n
+ self.t0 = time.time()
+
+ def eta(self, k):
+ if k > 0:
+ t = time.time()
+ u = self.t0 + ((t - self.t0) * self.n) // k
+ return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
+ else:
+ return "n.a."
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+log_string(f'device {device}')
+
+######################################################################
+
+def create_maze(h = 11, w = 15, nb_walls = 10):
+ a, k = 0, 0
+
+ while k < nb_walls:
+ while True:
+ if a == 0:
+ m = torch.zeros(h, w, dtype = torch.int64)
+ m[ 0, :] = 1
+ m[-1, :] = 1
+ m[ :, 0] = 1
+ m[ :, -1] = 1
+
+ r = torch.rand(4)
+
+ if r[0] <= 0.5:
+ i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
+ i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
+ i1, i2 = min(i1, i2), max(i1, i2)
+ if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
+ m[i1:i2+1, j] = 1
+ break
+ else:
+ i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
+ i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
+ j1, j2 = min(j1, j2), max(j1, j2)
+ if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
+ m[i, j1:j2+1] = 1
+ break
+ a += 1
+
+ if a > 10 * nb_walls: a, k = 0, 0
+
+ k += 1
+
+ return m
+
+######################################################################
+
+def random_free_position(walls):
+ p = torch.randperm(walls.numel())
+ k = p[walls.view(-1)[p] == 0][0].item()
+ return k//walls.size(1), k%walls.size(1)
+
+def create_transitions(walls, nb):
+ trans = walls.new_zeros((9,) + walls.size())
+ t = torch.randint(4, (nb,))
+ i, j = random_free_position(walls)
+
+ for k in range(t.size(0)):
+ di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
+ ip, jp = i + di, j + dj
+ if ip < 0 or ip >= walls.size(0) or \
+ jp < 0 or jp >= walls.size(1) or \
+ walls[ip, jp] > 0:
+ trans[t[k] + 4, i, j] += 1
+ else:
+ trans[t[k], i, j] += 1
+ i, j = ip, jp
+
+ n = trans[0:8].sum(dim = 0, keepdim = True)
+ trans[8:9] = n
+ trans[0:8] = trans[0:8] / (n + (n == 0).long())
+
+ return trans
+
+######################################################################
+
+def compute_distance(walls, i, j):
+ max_length = walls.numel()
+ dist = torch.full_like(walls, max_length)
+
+ dist[i, j] = 0
+ pred_dist = torch.empty_like(dist)
+
+ while True:
+ pred_dist.copy_(dist)
+ d = torch.cat(
+ (
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1]
+ ),
+ 0).min(dim = 0)[0] + 1
+
+ dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+ dist = walls * max_length + (1 - walls) * dist
+
+ if dist.equal(pred_dist): return dist * (1 - walls)
+
+######################################################################
+
+def compute_policy(walls, i, j):
+ distance = compute_distance(walls, i, j)
+ distance = distance + walls.numel() * walls
+
+ value = distance.new_full((4,) + distance.size(), walls.numel())
+ value[0, : , 1: ] = distance[ : , :-1]
+ value[1, : , :-1] = distance[ : , 1: ]
+ value[2, 1: , : ] = distance[ :-1, : ]
+ value[3, :-1, : ] = distance[1: , : ]
+
+ proba = (value.min(dim = 0)[0][None] == value).float()
+ proba = proba / proba.sum(dim = 0)[None]
+ proba = proba * (1 - walls)
+
+ return proba
+
+######################################################################
+
+def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
+ input = torch.empty(nb, 10, h, w)
+ targets = torch.empty(nb, 2, h, w)
+
+ if type(traj_length) == tuple:
+ l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long()
+ else:
+ l = torch.full((nb,), traj_length).long()
+
+ eta = ETA(nb)
+
+ for n in range(nb):
+ if n%(max(10, nb//1000)) == 0:
+ log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')
+
+ walls = create_maze(h, w, nb_walls)
+ trans = create_transitions(walls, l[n])
+
+ i, j = random_free_position(walls)
+ start = walls.new_zeros(walls.size())
+ start[i, j] = 1
+ dist = compute_distance(walls, i, j)
+
+ input[n] = torch.cat((trans, start[None]), 0)
+ targets[n] = torch.cat((walls[None], dist[None]), 0)
+
+ return input, targets
+
+######################################################################
+
+def save_image(name, input, targets, output = None):
+ input, targets = input.cpu(), targets.cpu()
+
+ weight = torch.tensor(
+ [
+ [ 1.0, 0.0, 0.0 ],
+ [ 1.0, 1.0, 0.0 ],
+ [ 0.0, 1.0, 0.0 ],
+ [ 0.0, 0.0, 1.0 ],
+ ] ).t()[:, :, None, None]
+
+ # img_trans = F.conv2d(input[:, 0:5], weight)
+ # img_trans = img_trans / img_trans.max()
+
+ img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
+ img_trans = 1 - img_trans / img_trans.max()
+
+ img_start = input[:, 9:10].expand(-1, 3, -1, -1)
+ img_start = 1 - img_start / img_start.max()
+
+ img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
+ img_walls = 1 - img_walls / img_walls.max()
+
+ # img_pi = F.conv2d(targets[:, 2:6], weight)
+ # img_pi = img_pi / img_pi.max()
+
+ img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
+ img_dist = img_dist / img_dist.max()
+
+ img = (
+ img_start[:, None],
+ img_trans[:, None],
+ img_walls[:, None],
+ # img_pi[:, None],
+ img_dist[:, None],
+ )
+
+ if output is not None:
+ output = output.cpu()
+ img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
+ img_walls = 1 - img_walls / img_walls.max()
+
+ # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
+ # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]
+
+ img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
+ img_dist = img_dist / img_dist.max()
+
+ img += (
+ img_walls[:, None],
+ img_dist[:, None],
+ # img_pi[:, None],
+ )
+
+ img_all = torch.cat(img, 1)
+
+ img_all = img_all.view(
+ img_all.size(0) * img_all.size(1),
+ img_all.size(2),
+ img_all.size(3),
+ img_all.size(4),
+ )
+
+ torchvision.utils.save_image(
+ img_all,
+ name,
+ padding = 1, pad_value = 0.5, nrow = len(img)
+ )
+
+ log_string(f'Wrote {name}')
+
+######################################################################
+
+class Net(nn.Module):
+ def __init__(self):
+ super().__init__()
+ nh = 128
+ self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
+ self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
+ self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
+ self.conv4 = nn.Conv2d(nh, 2, kernel_size = 5, padding = 2)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ x = F.relu(self.conv3(x))
+ x = self.conv4(x)
+ return x
+
+######################################################################
+
+class ResNetBlock(nn.Module):
+ def __init__(self, nb_channels, kernel_size):
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(nb_channels, nb_channels,
+ kernel_size = kernel_size,
+ padding = (kernel_size - 1) // 2)
+
+ self.bn1 = nn.BatchNorm2d(nb_channels)
+
+ self.conv2 = nn.Conv2d(nb_channels, nb_channels,
+ kernel_size = kernel_size,
+ padding = (kernel_size - 1) // 2)
+
+ self.bn2 = nn.BatchNorm2d(nb_channels)
+
+ def forward(self, x):
+ y = F.relu(self.bn1(self.conv1(x)))
+ y = F.relu(x + self.bn2(self.conv2(y)))
+ return y
+
+class ResNet(nn.Module):
+
+ def __init__(self,
+ in_channels, out_channels,
+ nb_residual_blocks, nb_channels, kernel_size):
+ super().__init__()
+
+ self.pre_process = nn.Sequential(
+ nn.Conv2d(in_channels, nb_channels,
+ kernel_size = kernel_size,
+ padding = (kernel_size - 1) // 2),
+ nn.BatchNorm2d(nb_channels),
+ nn.ReLU(inplace = True),
+ )
+
+ blocks = []
+ for k in range(nb_residual_blocks):
+ blocks.append(ResNetBlock(nb_channels, kernel_size))
+
+ self.resnet_blocks = nn.Sequential(*blocks)
+
+ self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)
+
+ def forward(self, x):
+ x = self.pre_process(x)
+ x = self.resnet_blocks(x)
+ x = self.post_process(x)
+ return x
+
+######################################################################
+
+data_filename = 'path.dat'
+
+try:
+ input, targets = torch.load(data_filename)
+ log_string('Data loaded.')
+ assert input.size(0) == args.nb_for_train + args.nb_for_test and \
+ input.size(1) == 10 and \
+ input.size(2) == args.world_height and \
+ input.size(3) == args.world_width and \
+ \
+ targets.size(0) == args.nb_for_train + args.nb_for_test and \
+ targets.size(1) == 2 and \
+ targets.size(2) == args.world_height and \
+ targets.size(3) == args.world_width
+
+except FileNotFoundError:
+ log_string('Generating data.')
+
+ input, targets = create_maze_data(
+ nb = args.nb_for_train + args.nb_for_test,
+ h = args.world_height, w = args.world_width,
+ nb_walls = args.world_nb_walls,
+ traj_length = (100, 10000)
+ )
+
+ torch.save((input, targets), data_filename)
+
+except:
+ log_string('Error when loading data.')
+ exit(1)
+
+######################################################################
+
+for n in vars(args):
+ log_string(f'args.{n} {getattr(args, n)}')
+
+model = ResNet(
+ in_channels = 10, out_channels = 2,
+ nb_residual_blocks = args.nb_residual_blocks,
+ nb_channels = args.nb_channels,
+ kernel_size = args.kernel_size
+)
+
+criterion = nn.MSELoss()
+
+model.to(device)
+criterion.to(device)
+
+input, targets = input.to(device), targets.to(device)
+
+train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
+test_input, test_targets = input[args.nb_for_train:], targets[args.nb_for_train:]
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+eta = ETA(args.nb_epochs)
+
+for e in range(args.nb_epochs):
+
+ if e < args.nb_epochs // 2:
+ lr = 1e-2
+ else:
+ lr = 1e-3
+
+ optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+
+ acc_train_loss = 0.0
+
+ for input, targets in zip(train_input.split(args.batch_size),
+ train_targets.split(args.batch_size)):
+ output = model(input)
+
+ loss = criterion(output, targets)
+ acc_train_loss += loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ test_loss = 0.0
+
+ for input, targets in zip(test_input.split(args.batch_size),
+ test_targets.split(args.batch_size)):
+ output = model(input)
+ loss = criterion(output, targets)
+ test_loss += loss.item()
+
+ log_string(
+ f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
+ )
+
+ # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
+ # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
+
+ save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
+ save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])
+
+######################################################################