#!/usr/bin/env python
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
import sys, math, time, argparse
import torch, torchvision
######################################################################
parser = argparse.ArgumentParser(
- description='Path-planning as denoising.',
- formatter_class = argparse.ArgumentDefaultsHelpFormatter
+ description="Path-planning as denoising.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument('--nb_epochs',
- type = int, default = 25)
+parser.add_argument("--nb_epochs", type=int, default=25)
-parser.add_argument('--batch_size',
- type = int, default = 100)
+parser.add_argument("--batch_size", type=int, default=100)
-parser.add_argument('--nb_residual_blocks',
- type = int, default = 16)
+parser.add_argument("--nb_residual_blocks", type=int, default=16)
-parser.add_argument('--nb_channels',
- type = int, default = 128)
+parser.add_argument("--nb_channels", type=int, default=128)
-parser.add_argument('--kernel_size',
- type = int, default = 3)
+parser.add_argument("--kernel_size", type=int, default=3)
-parser.add_argument('--nb_for_train',
- type = int, default = 100000)
+parser.add_argument("--nb_for_train", type=int, default=100000)
-parser.add_argument('--nb_for_test',
- type = int, default = 10000)
+parser.add_argument("--nb_for_test", type=int, default=10000)
-parser.add_argument('--world_height',
- type = int, default = 23)
+parser.add_argument("--world_height", type=int, default=23)
-parser.add_argument('--world_width',
- type = int, default = 31)
+parser.add_argument("--world_width", type=int, default=31)
-parser.add_argument('--world_nb_walls',
- type = int, default = 15)
+parser.add_argument("--world_nb_walls", type=int, default=15)
-parser.add_argument('--seed',
- type = int, default = 0,
- help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+ "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
######################################################################
######################################################################
-label=''
+label = ""
-log_file = open(f'path_{label}train.log', 'w')
+log_file = open(f"path_{label}train.log", "w")
######################################################################
+
def log_string(s):
- t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
- s = t + ' - ' + s
+ t = time.strftime("%Y%m%d-%H:%M:%S", time.localtime())
+ s = t + " - " + s
if log_file is not None:
- log_file.write(s + '\n')
+ log_file.write(s + "\n")
log_file.flush()
print(s)
sys.stdout.flush()
+
######################################################################
+
class ETA:
def __init__(self, n):
self.n = n
if k > 0:
t = time.time()
u = self.t0 + ((t - self.t0) * self.n) // k
- return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
+ return time.strftime("%Y%m%d-%H:%M:%S", time.localtime(u))
else:
return "n.a."
+
######################################################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-log_string(f'device {device}')
+log_string(f"device {device}")
######################################################################
-def create_maze(h = 11, w = 15, nb_walls = 10):
+
+def create_maze(h=11, w=15, nb_walls=10):
a, k = 0, 0
while k < nb_walls:
while True:
if a == 0:
- m = torch.zeros(h, w, dtype = torch.int64)
- m[ 0, :] = 1
- m[-1, :] = 1
- m[ :, 0] = 1
- m[ :, -1] = 1
+ m = torch.zeros(h, w, dtype=torch.int64)
+ m[0, :] = 1
+ m[-1, :] = 1
+ m[:, 0] = 1
+ m[:, -1] = 1
r = torch.rand(4)
if r[0] <= 0.5:
- i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
- i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
+ i1, i2, j = (
+ int((r[1] * h).item()),
+ int((r[2] * h).item()),
+ int((r[3] * w).item()),
+ )
+ i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
i1, i2 = min(i1, i2), max(i1, i2)
- if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
- m[i1:i2+1, j] = 1
+ if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+ m[i1 : i2 + 1, j] = 1
break
else:
- i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
- i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
+ i, j1, j2 = (
+ int((r[1] * h).item()),
+ int((r[2] * w).item()),
+ int((r[3] * w).item()),
+ )
+ i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
j1, j2 = min(j1, j2), max(j1, j2)
- if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
- m[i, j1:j2+1] = 1
+ if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+ m[i, j1 : j2 + 1] = 1
break
a += 1
- if a > 10 * nb_walls: a, k = 0, 0
+ if a > 10 * nb_walls:
+ a, k = 0, 0
k += 1
return m
+
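For reference, the maze generator can be exercised on its own; a minimal sketch (assuming `create_maze` as defined above and the `torch` import at the top of the file):

torch.manual_seed(0)
m = create_maze(h=11, w=15, nb_walls=10)  # 0 = free cell, 1 = wall
# Quick ASCII rendering for a visual sanity check.
print("\n".join("".join("#" if c else "." for c in row) for row in m.tolist()))
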
######################################################################
+
def random_free_position(walls):
p = torch.randperm(walls.numel())
k = p[walls.view(-1)[p] == 0][0].item()
- return k//walls.size(1), k%walls.size(1)
+ return k // walls.size(1), k % walls.size(1)
+
def create_transitions(walls, nb):
trans = walls.new_zeros((9,) + walls.size())
i, j = random_free_position(walls)
for k in range(t.size(0)):
- di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
+ di, dj = [(0, 1), (1, 0), (0, -1), (-1, 0)][t[k]]
ip, jp = i + di, j + dj
- if ip < 0 or ip >= walls.size(0) or \
- jp < 0 or jp >= walls.size(1) or \
- walls[ip, jp] > 0:
+ if (
+ ip < 0
+ or ip >= walls.size(0)
+ or jp < 0
+ or jp >= walls.size(1)
+ or walls[ip, jp] > 0
+ ):
trans[t[k] + 4, i, j] += 1
else:
trans[t[k], i, j] += 1
i, j = ip, jp
- n = trans[0:8].sum(dim = 0, keepdim = True)
+ n = trans[0:8].sum(dim=0, keepdim=True)
trans[8:9] = n
trans[0:8] = trans[0:8] / (n + (n == 0).long())
return trans
+
######################################################################
+
def compute_distance(walls, i, j):
max_length = walls.numel()
dist = torch.full_like(walls, max_length)
while True:
pred_dist.copy_(dist)
- d = torch.cat(
- (
- dist[None, 1:-1, 0:-2],
- dist[None, 2:, 1:-1],
- dist[None, 1:-1, 2:],
- dist[None, 0:-2, 1:-1]
- ),
- 0).min(dim = 0)[0] + 1
+ d = (
+ torch.cat(
+ (
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1],
+ ),
+ 0,
+ ).min(dim=0)[0]
+ + 1
+ )
dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
dist = walls * max_length + (1 - walls) * dist
- if dist.equal(pred_dist): return dist * (1 - walls)
+ if dist.equal(pred_dist):
+ return dist * (1 - walls)
+
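The loop above relaxes a distance field to a fixed point: every free interior cell repeatedly takes 1 + the minimum of its four neighbours, and walls are pinned to an unreachable value. The diff does not show the initialisation of `dist` at the goal cell or of `pred_dist`, so here is a hedged standalone sketch of the same min-plus relaxation (a reconstruction under the assumption that `walls` is a 0/1 integer tensor with a closed border and `(i, j)` is a free cell, not the author's exact code):

def distance_to(walls, i, j):
    big = walls.numel()  # plays the role of infinity
    dist = torch.full_like(walls, big)
    dist[i, j] = 0  # goal cell
    while True:
        prev = dist.clone()
        # 1 + minimum over the four neighbours of every interior cell
        d = torch.stack(
            (dist[1:-1, 0:-2], dist[2:, 1:-1], dist[1:-1, 2:], dist[0:-2, 1:-1])
        ).min(dim=0)[0] + 1
        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
        dist = walls * big + (1 - walls) * dist  # keep walls unreachable
        if dist.equal(prev):
            return dist * (1 - walls)  # zero the walls in the returned field
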
######################################################################
+
def compute_policy(walls, i, j):
distance = compute_distance(walls, i, j)
distance = distance + walls.numel() * walls
value = distance.new_full((4,) + distance.size(), walls.numel())
- value[0, : , 1: ] = distance[ : , :-1]
- value[1, : , :-1] = distance[ : , 1: ]
- value[2, 1: , : ] = distance[ :-1, : ]
- value[3, :-1, : ] = distance[1: , : ]
+ value[0, :, 1:] = distance[:, :-1]
+ value[1, :, :-1] = distance[:, 1:]
+ value[2, 1:, :] = distance[:-1, :]
+ value[3, :-1, :] = distance[1:, :]
- proba = (value.min(dim = 0)[0][None] == value).float()
- proba = proba / proba.sum(dim = 0)[None]
+ proba = (value.min(dim=0)[0][None] == value).float()
+ proba = proba / proba.sum(dim=0)[None]
proba = proba * (1 - walls)
return proba
+
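`compute_policy` therefore returns a `(4, h, w)` tensor holding, for each free cell, a distribution over the moves that reduce the distance to the goal. Reading the shifts above, planes 0/1/2/3 should correspond to stepping left/right/up/down; a greedy rollout sketch under that assumption (the helper name and move order are mine, not from the original file):

def greedy_rollout(proba, start, goal, max_steps=1000):
    # Assumed plane-to-move mapping: 0 -> left, 1 -> right, 2 -> up, 3 -> down.
    moves = [(0, -1), (0, 1), (-1, 0), (1, 0)]
    i, j = start
    path = [(i, j)]
    for _ in range(max_steps):
        if (i, j) == goal:
            break
        di, dj = moves[proba[:, i, j].argmax().item()]
        i, j = i + di, j + dj
        path.append((i, j))
    return path
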
######################################################################
-def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
+
+def create_maze_data(nb, h=11, w=17, nb_walls=8, traj_length=50):
input = torch.empty(nb, 10, h, w)
targets = torch.empty(nb, 2, h, w)
eta = ETA(nb)
for n in range(nb):
- if n%(max(10, nb//1000)) == 0:
- log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')
+ if n % (max(10, nb // 1000)) == 0:
+ log_string(f"{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}")
walls = create_maze(h, w, nb_walls)
trans = create_transitions(walls, l[n])
return input, targets
+
######################################################################
-def save_image(name, input, targets, output = None):
+
+def save_image(name, input, targets, output=None):
input, targets = input.cpu(), targets.cpu()
weight = torch.tensor(
[
- [ 1.0, 0.0, 0.0 ],
- [ 1.0, 1.0, 0.0 ],
- [ 0.0, 1.0, 0.0 ],
- [ 0.0, 0.0, 1.0 ],
- ] ).t()[:, :, None, None]
+ [1.0, 0.0, 0.0],
+ [1.0, 1.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ]
+ ).t()[:, :, None, None]
# img_trans = F.conv2d(input[:, 0:5], weight)
# img_trans = img_trans / img_trans.max()
img_walls[:, None],
# img_pi[:, None],
img_dist[:, None],
- )
+ )
if output is not None:
output = output.cpu()
img_all.size(4),
)
- torchvision.utils.save_image(
- img_all,
- name,
- padding = 1, pad_value = 0.5, nrow = len(img)
- )
- log_string(f'Wrote {name}')
+ torchvision.utils.save_image(img_all, name, padding=1, pad_value=0.5, nrow=len(img))
+
+ log_string(f"Wrote {name}")
######################################################################
+
class Net(nn.Module):
def __init__(self):
super().__init__()
nh = 128
- self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
- self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
- self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
- self.conv4 = nn.Conv2d(nh, 2, kernel_size = 5, padding = 2)
+ self.conv1 = nn.Conv2d(6, nh, kernel_size=5, padding=2)
+ self.conv2 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
+ self.conv3 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
+ self.conv4 = nn.Conv2d(nh, 2, kernel_size=5, padding=2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.conv4(x)
return x
+
######################################################################
+
class ResNetBlock(nn.Module):
def __init__(self, nb_channels, kernel_size):
super().__init__()
- self.conv1 = nn.Conv2d(nb_channels, nb_channels,
- kernel_size = kernel_size,
- padding = (kernel_size - 1) // 2)
+ self.conv1 = nn.Conv2d(
+ nb_channels,
+ nb_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ )
self.bn1 = nn.BatchNorm2d(nb_channels)
- self.conv2 = nn.Conv2d(nb_channels, nb_channels,
- kernel_size = kernel_size,
- padding = (kernel_size - 1) // 2)
+ self.conv2 = nn.Conv2d(
+ nb_channels,
+ nb_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ )
self.bn2 = nn.BatchNorm2d(nb_channels)
y = F.relu(x + self.bn2(self.conv2(y)))
return y
-class ResNet(nn.Module):
- def __init__(self,
- in_channels, out_channels,
- nb_residual_blocks, nb_channels, kernel_size):
+class ResNet(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, nb_residual_blocks, nb_channels, kernel_size
+ ):
super().__init__()
self.pre_process = nn.Sequential(
- nn.Conv2d(in_channels, nb_channels,
- kernel_size = kernel_size,
- padding = (kernel_size - 1) // 2),
+ nn.Conv2d(
+ in_channels,
+ nb_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ ),
nn.BatchNorm2d(nb_channels),
- nn.ReLU(inplace = True),
+ nn.ReLU(inplace=True),
)
blocks = []
self.resnet_blocks = nn.Sequential(*blocks)
- self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)
+ self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.pre_process(x)
x = self.post_process(x)
return x
+
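With the data layout used here (10 input planes per world, 2 target planes) and every convolution padded to preserve spatial size, the network maps `(B, 10, h, w)` to `(B, 2, h, w)`. A quick hedged shape check, assuming the full class including the residual blocks elided from the diff:

net = ResNet(
    in_channels=10, out_channels=2, nb_residual_blocks=2, nb_channels=32, kernel_size=3
)
x = torch.randn(4, 10, 23, 31)  # a batch of 4 worlds at the default size
print(net(x).shape)  # expected: torch.Size([4, 2, 23, 31])
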
######################################################################
-data_filename = 'path.dat'
+data_filename = "path.dat"
try:
input, targets = torch.load(data_filename)
- log_string('Data loaded.')
- assert input.size(0) == args.nb_for_train + args.nb_for_test and \
- input.size(1) == 10 and \
- input.size(2) == args.world_height and \
- input.size(3) == args.world_width and \
- \
- targets.size(0) == args.nb_for_train + args.nb_for_test and \
- targets.size(1) == 2 and \
- targets.size(2) == args.world_height and \
- targets.size(3) == args.world_width
+ log_string("Data loaded.")
+ assert (
+ input.size(0) == args.nb_for_train + args.nb_for_test
+ and input.size(1) == 10
+ and input.size(2) == args.world_height
+ and input.size(3) == args.world_width
+ and targets.size(0) == args.nb_for_train + args.nb_for_test
+ and targets.size(1) == 2
+ and targets.size(2) == args.world_height
+ and targets.size(3) == args.world_width
+ )
except FileNotFoundError:
- log_string('Generating data.')
+ log_string("Generating data.")
input, targets = create_maze_data(
- nb = args.nb_for_train + args.nb_for_test,
- h = args.world_height, w = args.world_width,
- nb_walls = args.world_nb_walls,
- traj_length = (100, 10000)
+ nb=args.nb_for_train + args.nb_for_test,
+ h=args.world_height,
+ w=args.world_width,
+ nb_walls=args.world_nb_walls,
+ traj_length=(100, 10000),
)
torch.save((input, targets), data_filename)
except:
- log_string('Error when loading data.')
+ log_string("Error when loading data.")
exit(1)
######################################################################
for n in vars(args):
- log_string(f'args.{n} {getattr(args, n)}')
+ log_string(f"args.{n} {getattr(args, n)}")
model = ResNet(
- in_channels = 10, out_channels = 2,
- nb_residual_blocks = args.nb_residual_blocks,
- nb_channels = args.nb_channels,
- kernel_size = args.kernel_size
+ in_channels=10,
+ out_channels=2,
+ nb_residual_blocks=args.nb_residual_blocks,
+ nb_channels=args.nb_channels,
+ kernel_size=args.kernel_size,
)
criterion = nn.MSELoss()
input, targets = input.to(device), targets.to(device)
-train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
-test_input, test_targets = input[args.nb_for_train:], targets[args.nb_for_train:]
+train_input, train_targets = input[: args.nb_for_train], targets[: args.nb_for_train]
+test_input, test_targets = input[args.nb_for_train :], targets[args.nb_for_train :]
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
else:
lr = 1e-3
- optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
acc_train_loss = 0.0
- for input, targets in zip(train_input.split(args.batch_size),
- train_targets.split(args.batch_size)):
+ for input, targets in zip(
+ train_input.split(args.batch_size), train_targets.split(args.batch_size)
+ ):
output = model(input)
loss = criterion(output, targets)
test_loss = 0.0
- for input, targets in zip(test_input.split(args.batch_size),
- test_targets.split(args.batch_size)):
+ for input, targets in zip(
+ test_input.split(args.batch_size), test_targets.split(args.batch_size)
+ ):
output = model(input)
loss = criterion(output, targets)
test_loss += loss.item()
log_string(
- f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
+ f"{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}"
)
# save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
# save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
- save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
- save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])
+ save_image(
+ f"train_{e:04d}.png",
+ train_input[:8],
+ train_targets[:8],
+ model(train_input[:8])[:, 0:2],
+ )
+ save_image(
+ f"test_{e:04d}.png",
+ test_input[:8],
+ test_targets[:8],
+ model(test_input[:8])[:, 0:2],
+ )
######################################################################
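A training run with the defaults above can presumably be launched as follows (hypothetical invocation; the script name `path.py` is an assumption suggested by the `path.dat` and `path_{label}train.log` prefixes):

python path.py --nb_epochs 25 --batch_size 100 --nb_residual_blocks 16 --nb_channels 128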