From 6df2a3a958554b930e0da6f7884c96352780df6f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 15 Feb 2023 20:15:06 +0100 Subject: [PATCH] Initial commit --- path.py | 483 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 483 insertions(+) create mode 100755 path.py diff --git a/path.py b/path.py new file mode 100755 index 0000000..c26afed --- /dev/null +++ b/path.py @@ -0,0 +1,483 @@ +#!/usr/bin/env python + +import sys, math, time, argparse + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +parser = argparse.ArgumentParser( + description='Path-planning as denoising.', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--nb_epochs', + type = int, default = 25) + +parser.add_argument('--batch_size', + type = int, default = 100) + +parser.add_argument('--nb_residual_blocks', + type = int, default = 16) + +parser.add_argument('--nb_channels', + type = int, default = 128) + +parser.add_argument('--kernel_size', + type = int, default = 3) + +parser.add_argument('--nb_for_train', + type = int, default = 100000) + +parser.add_argument('--nb_for_test', + type = int, default = 10000) + +parser.add_argument('--world_height', + type = int, default = 23) + +parser.add_argument('--world_width', + type = int, default = 31) + +parser.add_argument('--world_nb_walls', + type = int, default = 15) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + +###################################################################### + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) + +###################################################################### + +label='' + +log_file = open(f'path_{label}train.log', 'w') + +###################################################################### + +def log_string(s): + t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime()) + s = t + ' - ' + s + if log_file is not None: + log_file.write(s + '\n') + log_file.flush() + + print(s) + sys.stdout.flush() + +###################################################################### + +class ETA: + def __init__(self, n): + self.n = n + self.t0 = time.time() + + def eta(self, k): + if k > 0: + t = time.time() + u = self.t0 + ((t - self.t0) * self.n) // k + return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u)) + else: + return "n.a." + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +log_string(f'device {device}') + +###################################################################### + +def create_maze(h = 11, w = 15, nb_walls = 10): + a, k = 0, 0 + + while k < nb_walls: + while True: + if a == 0: + m = torch.zeros(h, w, dtype = torch.int64) + m[ 0, :] = 1 + m[-1, :] = 1 + m[ :, 0] = 1 + m[ :, -1] = 1 + + r = torch.rand(4) + + if r[0] <= 0.5: + i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item()) + i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2 + i1, i2 = min(i1, i2), max(i1, i2) + if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1: + m[i1:i2+1, j] = 1 + break + else: + i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item()) + i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2 + j1, j2 = min(j1, j2), max(j1, j2) + if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1: + m[i, j1:j2+1] = 1 + break + a += 1 + + if a > 10 * nb_walls: a, k = 0, 0 + + k += 1 + + return m + +###################################################################### + +def random_free_position(walls): + p = torch.randperm(walls.numel()) + k = p[walls.view(-1)[p] == 0][0].item() + return k//walls.size(1), k%walls.size(1) + +def create_transitions(walls, nb): + trans = walls.new_zeros((9,) + walls.size()) + t = torch.randint(4, (nb,)) + i, j = random_free_position(walls) + + for k in range(t.size(0)): + di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]] + ip, jp = i + di, j + dj + if ip < 0 or ip >= walls.size(0) or \ + jp < 0 or jp >= walls.size(1) or \ + walls[ip, jp] > 0: + trans[t[k] + 4, i, j] += 1 + else: + trans[t[k], i, j] += 1 + i, j = ip, jp + + n = trans[0:8].sum(dim = 0, keepdim = True) + trans[8:9] = n + trans[0:8] = trans[0:8] / (n + (n == 0).long()) + + return trans + +###################################################################### + +def compute_distance(walls, i, j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[i, j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + d = torch.cat( + ( + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1] + ), + 0).min(dim = 0)[0] + 1 + + dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d) + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): return dist * (1 - walls) + +###################################################################### + +def compute_policy(walls, i, j): + distance = compute_distance(walls, i, j) + distance = distance + walls.numel() * walls + + value = distance.new_full((4,) + distance.size(), walls.numel()) + value[0, : , 1: ] = distance[ : , :-1] + value[1, : , :-1] = distance[ : , 1: ] + value[2, 1: , : ] = distance[ :-1, : ] + value[3, :-1, : ] = distance[1: , : ] + + proba = (value.min(dim = 0)[0][None] == value).float() + proba = proba / proba.sum(dim = 0)[None] + proba = proba * (1 - walls) + + return proba + +###################################################################### + +def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50): + input = torch.empty(nb, 10, h, w) + targets = torch.empty(nb, 2, h, w) + + if type(traj_length) == tuple: + l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long() + else: + l = torch.full((nb,), traj_length).long() + + eta = ETA(nb) + + for n in range(nb): + if n%(max(10, nb//1000)) == 0: + log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}') + + walls = create_maze(h, w, nb_walls) + trans = create_transitions(walls, l[n]) + + i, j = random_free_position(walls) + start = walls.new_zeros(walls.size()) + start[i, j] = 1 + dist = compute_distance(walls, i, j) + + input[n] = torch.cat((trans, start[None]), 0) + targets[n] = torch.cat((walls[None], dist[None]), 0) + + return input, targets + +###################################################################### + +def save_image(name, input, targets, output = None): + input, targets = input.cpu(), targets.cpu() + + weight = torch.tensor( + [ + [ 1.0, 0.0, 0.0 ], + [ 1.0, 1.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ], + ] ).t()[:, :, None, None] + + # img_trans = F.conv2d(input[:, 0:5], weight) + # img_trans = img_trans / img_trans.max() + + img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1) + img_trans = 1 - img_trans / img_trans.max() + + img_start = input[:, 9:10].expand(-1, 3, -1, -1) + img_start = 1 - img_start / img_start.max() + + img_walls = targets[:, 0:1].expand(-1, 3, -1, -1) + img_walls = 1 - img_walls / img_walls.max() + + # img_pi = F.conv2d(targets[:, 2:6], weight) + # img_pi = img_pi / img_pi.max() + + img_dist = targets[:, 1:2].expand(-1, 3, -1, -1) + img_dist = img_dist / img_dist.max() + + img = ( + img_start[:, None], + img_trans[:, None], + img_walls[:, None], + # img_pi[:, None], + img_dist[:, None], + ) + + if output is not None: + output = output.cpu() + img_walls = output[:, 0:1].expand(-1, 3, -1, -1) + img_walls = 1 - img_walls / img_walls.max() + + # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight) + # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1] + + img_dist = output[:, 1:2].expand(-1, 3, -1, -1) + img_dist = img_dist / img_dist.max() + + img += ( + img_walls[:, None], + img_dist[:, None], + # img_pi[:, None], + ) + + img_all = torch.cat(img, 1) + + img_all = img_all.view( + img_all.size(0) * img_all.size(1), + img_all.size(2), + img_all.size(3), + img_all.size(4), + ) + + torchvision.utils.save_image( + img_all, + name, + padding = 1, pad_value = 0.5, nrow = len(img) + ) + + log_string(f'Wrote {name}') + +###################################################################### + +class Net(nn.Module): + def __init__(self): + super().__init__() + nh = 128 + self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2) + self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2) + self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2) + self.conv4 = nn.Conv2d(nh, 2, kernel_size = 5, padding = 2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = self.conv4(x) + return x + +###################################################################### + +class ResNetBlock(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.conv1 = nn.Conv2d(nb_channels, nb_channels, + kernel_size = kernel_size, + padding = (kernel_size - 1) // 2) + + self.bn1 = nn.BatchNorm2d(nb_channels) + + self.conv2 = nn.Conv2d(nb_channels, nb_channels, + kernel_size = kernel_size, + padding = (kernel_size - 1) // 2) + + self.bn2 = nn.BatchNorm2d(nb_channels) + + def forward(self, x): + y = F.relu(self.bn1(self.conv1(x))) + y = F.relu(x + self.bn2(self.conv2(y))) + return y + +class ResNet(nn.Module): + + def __init__(self, + in_channels, out_channels, + nb_residual_blocks, nb_channels, kernel_size): + super().__init__() + + self.pre_process = nn.Sequential( + nn.Conv2d(in_channels, nb_channels, + kernel_size = kernel_size, + padding = (kernel_size - 1) // 2), + nn.BatchNorm2d(nb_channels), + nn.ReLU(inplace = True), + ) + + blocks = [] + for k in range(nb_residual_blocks): + blocks.append(ResNetBlock(nb_channels, kernel_size)) + + self.resnet_blocks = nn.Sequential(*blocks) + + self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1) + + def forward(self, x): + x = self.pre_process(x) + x = self.resnet_blocks(x) + x = self.post_process(x) + return x + +###################################################################### + +data_filename = 'path.dat' + +try: + input, targets = torch.load(data_filename) + log_string('Data loaded.') + assert input.size(0) == args.nb_for_train + args.nb_for_test and \ + input.size(1) == 10 and \ + input.size(2) == args.world_height and \ + input.size(3) == args.world_width and \ + \ + targets.size(0) == args.nb_for_train + args.nb_for_test and \ + targets.size(1) == 2 and \ + targets.size(2) == args.world_height and \ + targets.size(3) == args.world_width + +except FileNotFoundError: + log_string('Generating data.') + + input, targets = create_maze_data( + nb = args.nb_for_train + args.nb_for_test, + h = args.world_height, w = args.world_width, + nb_walls = args.world_nb_walls, + traj_length = (100, 10000) + ) + + torch.save((input, targets), data_filename) + +except: + log_string('Error when loading data.') + exit(1) + +###################################################################### + +for n in vars(args): + log_string(f'args.{n} {getattr(args, n)}') + +model = ResNet( + in_channels = 10, out_channels = 2, + nb_residual_blocks = args.nb_residual_blocks, + nb_channels = args.nb_channels, + kernel_size = args.kernel_size +) + +criterion = nn.MSELoss() + +model.to(device) +criterion.to(device) + +input, targets = input.to(device), targets.to(device) + +train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train] +test_input, test_targets = input[args.nb_for_train:], targets[args.nb_for_train:] + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +###################################################################### + +eta = ETA(args.nb_epochs) + +for e in range(args.nb_epochs): + + if e < args.nb_epochs // 2: + lr = 1e-2 + else: + lr = 1e-3 + + optimizer = torch.optim.Adam(model.parameters(), lr = lr) + + acc_train_loss = 0.0 + + for input, targets in zip(train_input.split(args.batch_size), + train_targets.split(args.batch_size)): + output = model(input) + + loss = criterion(output, targets) + acc_train_loss += loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + test_loss = 0.0 + + for input, targets in zip(test_input.split(args.batch_size), + test_targets.split(args.batch_size)): + output = model(input) + loss = criterion(output, targets) + test_loss += loss.item() + + log_string( + f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}' + ) + + # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])) + # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])) + + save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2]) + save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2]) + +###################################################################### -- 2.39.5