Initial commit
author François Fleuret <francois@fleuret.org>
Wed, 15 Feb 2023 19:15:06 +0000 (20:15 +0100)
committer François Fleuret <francois@fleuret.org>
Wed, 15 Feb 2023 19:15:06 +0000 (20:15 +0100)
path.py [new file with mode: 0755]

diff --git a/path.py b/path.py
new file mode 100755 (executable)
index 0000000..c26afed
--- /dev/null
+++ b/path.py
@@ -0,0 +1,483 @@
+#!/usr/bin/env python
+
+import sys, math, time, argparse
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description='Path-planning as denoising.',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = 25)
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100)
+
+parser.add_argument('--nb_residual_blocks',
+                    type = int, default = 16)
+
+parser.add_argument('--nb_channels',
+                    type = int, default = 128)
+
+parser.add_argument('--kernel_size',
+                    type = int, default = 3)
+
+parser.add_argument('--nb_for_train',
+                    type = int, default = 100000)
+
+parser.add_argument('--nb_for_test',
+                    type = int, default = 10000)
+
+parser.add_argument('--world_height',
+                    type = int, default = 23)
+
+parser.add_argument('--world_width',
+                    type = int, default = 31)
+
+parser.add_argument('--world_nb_walls',
+                    type = int, default = 15)
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+######################################################################
+
+label = ''
+
+log_file = open(f'path_{label}train.log', 'w')
+
+######################################################################
+
+def log_string(s):
+    t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
+    s = t + ' - ' + s
+    if log_file is not None:
+        log_file.write(s + '\n')
+        log_file.flush()
+
+    print(s)
+    sys.stdout.flush()
+
+######################################################################
+
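+# Linear-extrapolation ETA: after k of n steps, the estimated finishing
+# time is t0 + (elapsed time) * n / k, floored to whole seconds.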
+class ETA:
+    def __init__(self, n):
+        self.n = n
+        self.t0 = time.time()
+
+    def eta(self, k):
+        if k > 0:
+            t = time.time()
+            u = self.t0 + ((t - self.t0) * self.n) // k
+            return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
+        else:
+            return "n.a."
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+log_string(f'device {device}')
+
+######################################################################
+
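+# Returns a (h, w) int64 grid with 1 for wall cells and 0 for free
+# cells. The border is walled first, then nb_walls straight segments
+# are added by rejection sampling on even coordinates; a segment is
+# accepted if its extent is between 2 and half the grid size and it
+# crosses at most one existing wall cell. If too many proposals are
+# rejected in a row, the maze is restarted from scratch.
+# E.g. create_maze(11, 15, 10) yields a LongTensor of shape (11, 15).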
+def create_maze(h = 11, w = 15, nb_walls = 10):
+    a, k = 0, 0
+
+    while k < nb_walls:
+        while True:
+            if a == 0:
+                m = torch.zeros(h, w, dtype = torch.int64)
+                m[ 0,  :] = 1
+                m[-1,  :] = 1
+                m[ :,  0] = 1
+                m[ :, -1] = 1
+
+            r = torch.rand(4)
+
+            if r[0] <= 0.5:
+                i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
+                i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
+                i1, i2 = min(i1, i2), max(i1, i2)
+                if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
+                    m[i1:i2+1, j] = 1
+                    break
+            else:
+                i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
+                i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
+                j1, j2 = min(j1, j2), max(j1, j2)
+                if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
+                    m[i, j1:j2+1] = 1
+                    break
+            a += 1
+
+            if a > 10 * nb_walls: a, k = 0, 0
+
+        k += 1
+
+    return m
+
+######################################################################
+
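+# Returns the (row, column) coordinates of a cell drawn uniformly at
+# random among the wall-free cells.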
+def random_free_position(walls):
+    p = torch.randperm(walls.numel())
+    k = p[walls.view(-1)[p] == 0][0].item()
+    return k//walls.size(1), k%walls.size(1)
+
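+# Simulates a nb-step uniform random walk from a random free cell and
+# returns a (9, h, w) tensor: channels 0-3 count the moves (E, S, W, N)
+# that succeeded from each cell, channels 4-7 the attempts blocked by a
+# wall, and channel 8 the total number of attempts made from the cell.
+# Channels 0-7 are then normalized into empirical probabilities.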
+def create_transitions(walls, nb):
+    trans = walls.new_zeros((9,) + walls.size())
+    t = torch.randint(4, (nb,))
+    i, j = random_free_position(walls)
+
+    for k in range(t.size(0)):
+        di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
+        ip, jp = i + di, j + dj
+        if ip < 0 or ip >= walls.size(0) or \
+           jp < 0 or jp >= walls.size(1) or \
+           walls[ip, jp] > 0:
+            trans[t[k] + 4, i, j] += 1
+        else:
+            trans[t[k], i, j] += 1
+            i, j = ip, jp
+
+    n = trans[0:8].sum(dim = 0, keepdim = True)
+    trans[8:9] = n
+    trans[0:8] = trans[0:8] / (n + (n == 0).long())
+
+    return trans
+
+######################################################################
+
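+# Computes the geodesic distance from every free cell to (i, j) by
+# parallel relaxation: every sweep replaces each interior cell with the
+# minimum over its four neighbors plus one, walls being clamped to the
+# maximum value, until a fixed point is reached. Wall cells are zeroed
+# in the returned map.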
+def compute_distance(walls, i, j):
+    max_length = walls.numel()
+    dist = torch.full_like(walls, max_length)
+
+    dist[i, j] = 0
+    pred_dist = torch.empty_like(dist)
+
+    while True:
+        pred_dist.copy_(dist)
+        d = torch.cat(
+            (
+                dist[None, 1:-1, 0:-2],
+                dist[None, 2:, 1:-1],
+                dist[None, 1:-1, 2:],
+                dist[None, 0:-2, 1:-1]
+            ),
+        0).min(dim = 0)[0] + 1
+
+        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
+        dist = walls * max_length + (1 - walls) * dist
+
+        if dist.equal(pred_dist): return dist * (1 - walls)
+
+######################################################################
+
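+# Derives a greedy policy from the distance map to (i, j): in every
+# free cell the probability mass is spread uniformly over the moves
+# leading to a neighbor of minimal distance, and walls get probability
+# zero. Currently not used by the rest of the script.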
+def compute_policy(walls, i, j):
+    distance = compute_distance(walls, i, j)
+    distance = distance + walls.numel() * walls
+
+    value = distance.new_full((4,) + distance.size(), walls.numel())
+    value[0,  :  , 1:  ] = distance[ :  ,  :-1]
+    value[1,  :  ,  :-1] = distance[ :  , 1:  ]
+    value[2, 1:  ,  :  ] = distance[ :-1,  :  ]
+    value[3,  :-1,  :  ] = distance[1:  ,  :  ]
+
+    proba = (value.min(dim = 0)[0][None] == value).float()
+    proba = proba / proba.sum(dim = 0)[None]
+    proba = proba * (1 - walls)
+
+    return proba
+
+######################################################################
+
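+# Builds the dataset. For each sample, input packs the nine transition
+# channels of a random walk of length l[n] in a fresh maze plus the
+# one-hot map of an independently drawn free cell, and targets packs
+# the wall map and the distance map to that cell. traj_length can be a
+# scalar or a (min, max) pair from which lengths are drawn uniformly.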
+def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
+    input = torch.empty(nb, 10, h, w)
+    targets = torch.empty(nb, 2, h, w)
+
+    if isinstance(traj_length, tuple):
+        l = (torch.rand(nb) * (traj_length[1] - traj_length[0]) + traj_length[0]).long()
+    else:
+        l = torch.full((nb,), traj_length).long()
+
+    eta = ETA(nb)
+
+    for n in range(nb):
+        if n%(max(10, nb//1000)) == 0:
+            log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')
+
+        walls = create_maze(h, w, nb_walls)
+        trans = create_transitions(walls, l[n])
+
+        i, j = random_free_position(walls)
+        start = walls.new_zeros(walls.size())
+        start[i, j] = 1
+        dist = compute_distance(walls, i, j)
+
+        input[n] = torch.cat((trans, start[None]), 0)
+        targets[n] = torch.cat((walls[None], dist[None]), 0)
+
+    return input, targets
+
+######################################################################
+
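+# Saves a PNG with one row of panels per sample: the start cell, the
+# visit counts (channel 8), the true walls and the true distance map,
+# plus the predicted walls and distance map when output is given.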
+def save_image(name, input, targets, output = None):
+    input, targets = input.cpu(), targets.cpu()
+
+    weight = torch.tensor(
+        [
+            [ 1.0, 0.0, 0.0 ],
+            [ 1.0, 1.0, 0.0 ],
+            [ 0.0, 1.0, 0.0 ],
+            [ 0.0, 0.0, 1.0 ],
+        ] ).t()[:, :, None, None]
+
+    # img_trans = F.conv2d(input[:, 0:5], weight)
+    # img_trans = img_trans / img_trans.max()
+
+    img_trans = 1 / input[:, 8:9].expand(-1, 3, -1, -1)
+    img_trans = 1 - img_trans / img_trans.max()
+
+    img_start = input[:, 9:10].expand(-1, 3, -1, -1)
+    img_start = 1 - img_start / img_start.max()
+
+    img_walls = targets[:, 0:1].expand(-1, 3, -1, -1)
+    img_walls = 1 - img_walls / img_walls.max()
+
+    # img_pi = F.conv2d(targets[:, 2:6], weight)
+    # img_pi = img_pi / img_pi.max()
+
+    img_dist = targets[:, 1:2].expand(-1, 3, -1, -1)
+    img_dist = img_dist / img_dist.max()
+
+    img = (
+        img_start[:, None],
+        img_trans[:, None],
+        img_walls[:, None],
+        # img_pi[:, None],
+        img_dist[:, None],
+        )
+
+    if output is not None:
+        output = output.cpu()
+        img_walls = output[:, 0:1].expand(-1, 3, -1, -1)
+        img_walls = 1 - img_walls / img_walls.max()
+
+        # img_pi = F.conv2d(output[:, 2:6].mul(100).softmax(dim = 1), weight)
+        # img_pi = img_pi / img_pi.max() * output[:, 0:2].softmax(dim = 1)[:, 0:1]
+
+        img_dist = output[:, 1:2].expand(-1, 3, -1, -1)
+        img_dist = img_dist / img_dist.max()
+
+        img += (
+            img_walls[:, None],
+            img_dist[:, None],
+            # img_pi[:, None],
+        )
+
+    img_all = torch.cat(img, 1)
+
+    img_all = img_all.view(
+        img_all.size(0) * img_all.size(1),
+        img_all.size(2),
+        img_all.size(3),
+        img_all.size(4),
+    )
+
+    torchvision.utils.save_image(
+        img_all,
+        name,
+        padding = 1, pad_value = 0.5, nrow = len(img)
+    )
+
+    log_string(f'Wrote {name}')
+
+######################################################################
+
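+# A small plain convnet kept for reference; note that its 6 input
+# channels do not match the 10-channel input of this script, and the
+# training code below instantiates the ResNet instead.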
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        nh = 128
+        self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
+        self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
+        self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
+        self.conv4 = nn.Conv2d(nh,  2, kernel_size = 5, padding = 2)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = self.conv4(x)
+        return x
+
+######################################################################
+
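+# Standard residual block: two same-padded convolutions with batch
+# normalization, a ReLU in between, and a ReLU after adding the skip
+# connection.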
+class ResNetBlock(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(nb_channels, nb_channels,
+                               kernel_size = kernel_size,
+                               padding = (kernel_size - 1) // 2)
+
+        self.bn1 = nn.BatchNorm2d(nb_channels)
+
+        self.conv2 = nn.Conv2d(nb_channels, nb_channels,
+                               kernel_size = kernel_size,
+                               padding = (kernel_size - 1) // 2)
+
+        self.bn2 = nn.BatchNorm2d(nb_channels)
+
+    def forward(self, x):
+        y = F.relu(self.bn1(self.conv1(x)))
+        y = F.relu(x + self.bn2(self.conv2(y)))
+        return y
+
+class ResNet(nn.Module):
+
+    def __init__(self,
+                 in_channels, out_channels,
+                 nb_residual_blocks, nb_channels, kernel_size):
+        super().__init__()
+
+        self.pre_process = nn.Sequential(
+            nn.Conv2d(in_channels, nb_channels,
+                      kernel_size = kernel_size,
+                      padding = (kernel_size - 1) // 2),
+            nn.BatchNorm2d(nb_channels),
+            nn.ReLU(inplace = True),
+        )
+
+        blocks = []
+        for k in range(nb_residual_blocks):
+            blocks.append(ResNetBlock(nb_channels, kernel_size))
+
+        self.resnet_blocks = nn.Sequential(*blocks)
+
+        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)
+
+    def forward(self, x):
+        x = self.pre_process(x)
+        x = self.resnet_blocks(x)
+        x = self.post_process(x)
+        return x
+
+######################################################################
+
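+# Load the cached dataset if it exists and matches the current
+# configuration, otherwise generate it and cache it.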
+data_filename = 'path.dat'
+
+try:
+    input, targets = torch.load(data_filename)
+    log_string('Data loaded.')
+    assert input.size(0) == args.nb_for_train + args.nb_for_test and \
+           input.size(1) == 10 and \
+           input.size(2) == args.world_height and \
+           input.size(3) == args.world_width and \
+           \
+           targets.size(0) == args.nb_for_train + args.nb_for_test and \
+           targets.size(1) == 2 and \
+           targets.size(2) == args.world_height and \
+           targets.size(3) == args.world_width
+
+except FileNotFoundError:
+    log_string('Generating data.')
+
+    input, targets = create_maze_data(
+        nb = args.nb_for_train + args.nb_for_test,
+        h = args.world_height, w = args.world_width,
+        nb_walls = args.world_nb_walls,
+        traj_length = (100, 10000)
+    )
+
+    torch.save((input, targets), data_filename)
+
+except Exception:
+    log_string('Error when loading data.')
+    exit(1)
+
+######################################################################
+
+for n in vars(args):
+    log_string(f'args.{n} {getattr(args, n)}')
+
+model = ResNet(
+    in_channels = 10, out_channels = 2,
+    nb_residual_blocks = args.nb_residual_blocks,
+    nb_channels = args.nb_channels,
+    kernel_size = args.kernel_size
+)
+
+criterion = nn.MSELoss()
+
+model.to(device)
+criterion.to(device)
+
+input, targets = input.to(device), targets.to(device)
+
+train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
+test_input, test_targets   = input[args.nb_for_train:], targets[args.nb_for_train:]
+
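+# Normalize the inputs with statistics estimated on the training split
+# only.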
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
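+# Training with Adam and a two-stage learning rate: 1e-2 for the first
+# half of the epochs, 1e-3 for the second. Re-creating the optimizer at
+# every epoch also resets its internal moment estimates.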
+eta = ETA(args.nb_epochs)
+
+for e in range(args.nb_epochs):
+
+    if e < args.nb_epochs // 2:
+        lr = 1e-2
+    else:
+        lr = 1e-3
+
+    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+
+    model.train()
+    acc_train_loss = 0.0
+
+    for input, targets in zip(train_input.split(args.batch_size),
+                              train_targets.split(args.batch_size)):
+        output = model(input)
+
+        loss = criterion(output, targets)
+        acc_train_loss += loss.item()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    model.eval()
+    test_loss = 0.0
+
+    # Evaluation pass, without building the autograd graph
+    with torch.no_grad():
+        for input, targets in zip(test_input.split(args.batch_size),
+                                  test_targets.split(args.batch_size)):
+            output = model(input)
+            loss = criterion(output, targets)
+            test_loss += loss.item()
+
+    log_string(
+        f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
+    )
+
+    # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
+    # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
+
+    save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
+    save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])
+
+######################################################################