Added the copyright comment + reformatted with black. master
authorFrançois Fleuret <francois@fleuret.org>
Wed, 15 Feb 2023 19:50:16 +0000 (20:50 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 15 Feb 2023 19:50:16 +0000 (20:50 +0100)
path.py

diff --git a/path.py b/path.py
index c26afed..866eeb9 100755 (executable)
--- a/path.py
+++ b/path.py
@@ -1,5 +1,10 @@
 #!/usr/bin/env python
 
 #!/usr/bin/env python
 
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
 import sys, math, time, argparse
 
 import torch, torchvision
 import sys, math, time, argparse
 
 import torch, torchvision
@@ -10,43 +15,33 @@ from torch.nn import functional as F
 ######################################################################
 
 parser = argparse.ArgumentParser(
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description='Path-planning as denoising.',
-    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+    description="Path-planning as denoising.",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
 )
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = 25)
+parser.add_argument("--nb_epochs", type=int, default=25)
 
 
-parser.add_argument('--batch_size',
-                    type = int, default = 100)
+parser.add_argument("--batch_size", type=int, default=100)
 
 
-parser.add_argument('--nb_residual_blocks',
-                    type = int, default = 16)
+parser.add_argument("--nb_residual_blocks", type=int, default=16)
 
 
-parser.add_argument('--nb_channels',
-                    type = int, default = 128)
+parser.add_argument("--nb_channels", type=int, default=128)
 
 
-parser.add_argument('--kernel_size',
-                    type = int, default = 3)
+parser.add_argument("--kernel_size", type=int, default=3)
 
 
-parser.add_argument('--nb_for_train',
-                    type = int, default = 100000)
+parser.add_argument("--nb_for_train", type=int, default=100000)
 
 
-parser.add_argument('--nb_for_test',
-                    type = int, default = 10000)
+parser.add_argument("--nb_for_test", type=int, default=10000)
 
 
-parser.add_argument('--world_height',
-                    type = int, default = 23)
+parser.add_argument("--world_height", type=int, default=23)
 
 
-parser.add_argument('--world_width',
-                    type = int, default = 31)
+parser.add_argument("--world_width", type=int, default=31)
 
 
-parser.add_argument('--world_nb_walls',
-                    type = int, default = 15)
+parser.add_argument("--world_nb_walls", type=int, default=15)
 
 
-parser.add_argument('--seed',
-                    type = int, default = 0,
-                    help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -57,24 +52,27 @@ if args.seed >= 0:
 
 ######################################################################
 
 
 ######################################################################
 
-label=''
+label = ""
 
 
-log_file = open(f'path_{label}train.log', 'w')
+log_file = open(f"path_{label}train.log", "w")
 
 ######################################################################
 
 
 ######################################################################
 
+
 def log_string(s):
 def log_string(s):
-    t = time.strftime('%Y%m%d-%H:%M:%S', time.localtime())
-    s = t + ' - ' + s
+    t = time.strftime("%Y%m%d-%H:%M:%S", time.localtime())
+    s = t + " - " + s
     if log_file is not None:
     if log_file is not None:
-        log_file.write(s + '\n')
+        log_file.write(s + "\n")
         log_file.flush()
 
     print(s)
     sys.stdout.flush()
 
         log_file.flush()
 
     print(s)
     sys.stdout.flush()
 
+
 ######################################################################
 
 ######################################################################
 
+
 class ETA:
     def __init__(self, n):
         self.n = n
 class ETA:
     def __init__(self, n):
         self.n = n
@@ -84,60 +82,74 @@ class ETA:
         if k > 0:
             t = time.time()
             u = self.t0 + ((t - self.t0) * self.n) // k
         if k > 0:
             t = time.time()
             u = self.t0 + ((t - self.t0) * self.n) // k
-            return time.strftime('%Y%m%d-%H:%M:%S', time.localtime(u))
+            return time.strftime("%Y%m%d-%H:%M:%S", time.localtime(u))
         else:
             return "n.a."
 
         else:
             return "n.a."
 
+
 ######################################################################
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-log_string(f'device {device}')
+log_string(f"device {device}")
 
 ######################################################################
 
 
 ######################################################################
 
-def create_maze(h = 11, w = 15, nb_walls = 10):
+
+def create_maze(h=11, w=15, nb_walls=10):
     a, k = 0, 0
 
     while k < nb_walls:
         while True:
             if a == 0:
     a, k = 0, 0
 
     while k < nb_walls:
         while True:
             if a == 0:
-                m = torch.zeros(h, w, dtype = torch.int64)
-                m[ 0,  :] = 1
-                m[-1,  :] = 1
-                m[ :,  0] = 1
-                m[ :, -1] = 1
+                m = torch.zeros(h, w, dtype=torch.int64)
+                m[0, :] = 1
+                m[-1, :] = 1
+                m[:, 0] = 1
+                m[:, -1] = 1
 
             r = torch.rand(4)
 
             if r[0] <= 0.5:
 
             r = torch.rand(4)
 
             if r[0] <= 0.5:
-                i1, i2, j = int((r[1] * h).item()), int((r[2] * h).item()), int((r[3] * w).item())
-                i1, i2, j = i1 - i1%2, i2 - i2%2, j - j%2
+                i1, i2, j = (
+                    int((r[1] * h).item()),
+                    int((r[2] * h).item()),
+                    int((r[3] * w).item()),
+                )
+                i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
                 i1, i2 = min(i1, i2), max(i1, i2)
                 i1, i2 = min(i1, i2), max(i1, i2)
-                if i2 - i1 > 1 and i2 - i1 <= h/2 and m[i1:i2+1, j].sum() <= 1:
-                    m[i1:i2+1, j] = 1
+                if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
+                    m[i1 : i2 + 1, j] = 1
                     break
             else:
                     break
             else:
-                i, j1, j2 = int((r[1] * h).item()), int((r[2] * w).item()), int((r[3] * w).item())
-                i, j1, j2 = i - i%2, j1 - j1%2, j2 - j2%2
+                i, j1, j2 = (
+                    int((r[1] * h).item()),
+                    int((r[2] * w).item()),
+                    int((r[3] * w).item()),
+                )
+                i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
                 j1, j2 = min(j1, j2), max(j1, j2)
                 j1, j2 = min(j1, j2), max(j1, j2)
-                if j2 - j1 > 1 and j2 - j1 <= w/2 and m[i, j1:j2+1].sum() <= 1:
-                    m[i, j1:j2+1] = 1
+                if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
+                    m[i, j1 : j2 + 1] = 1
                     break
             a += 1
 
                     break
             a += 1
 
-            if a > 10 * nb_walls: a, k = 0, 0
+            if a > 10 * nb_walls:
+                a, k = 0, 0
 
         k += 1
 
     return m
 
 
         k += 1
 
     return m
 
+
 ######################################################################
 
 ######################################################################
 
+
 def random_free_position(walls):
     p = torch.randperm(walls.numel())
     k = p[walls.view(-1)[p] == 0][0].item()
 def random_free_position(walls):
     p = torch.randperm(walls.numel())
     k = p[walls.view(-1)[p] == 0][0].item()
-    return k//walls.size(1), k%walls.size(1)
+    return k // walls.size(1), k % walls.size(1)
+
 
 def create_transitions(walls, nb):
     trans = walls.new_zeros((9,) + walls.size())
 
 def create_transitions(walls, nb):
     trans = walls.new_zeros((9,) + walls.size())
@@ -145,24 +157,30 @@ def create_transitions(walls, nb):
     i, j = random_free_position(walls)
 
     for k in range(t.size(0)):
     i, j = random_free_position(walls)
 
     for k in range(t.size(0)):
-        di, dj = [ (0, 1), (1, 0), (0, -1), (-1, 0) ][t[k]]
+        di, dj = [(0, 1), (1, 0), (0, -1), (-1, 0)][t[k]]
         ip, jp = i + di, j + dj
         ip, jp = i + di, j + dj
-        if ip < 0 or ip >= walls.size(0) or \
-           jp < 0 or jp >= walls.size(1) or \
-           walls[ip, jp] > 0:
+        if (
+            ip < 0
+            or ip >= walls.size(0)
+            or jp < 0
+            or jp >= walls.size(1)
+            or walls[ip, jp] > 0
+        ):
             trans[t[k] + 4, i, j] += 1
         else:
             trans[t[k], i, j] += 1
             i, j = ip, jp
 
             trans[t[k] + 4, i, j] += 1
         else:
             trans[t[k], i, j] += 1
             i, j = ip, jp
 
-    n = trans[0:8].sum(dim = 0, keepdim = True)
+    n = trans[0:8].sum(dim=0, keepdim=True)
     trans[8:9] = n
     trans[0:8] = trans[0:8] / (n + (n == 0).long())
 
     return trans
 
     trans[8:9] = n
     trans[0:8] = trans[0:8] / (n + (n == 0).long())
 
     return trans
 
+
 ######################################################################
 
 ######################################################################
 
+
 def compute_distance(walls, i, j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 def compute_distance(walls, i, j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
@@ -172,41 +190,50 @@ def compute_distance(walls, i, j):
 
     while True:
         pred_dist.copy_(dist)
 
     while True:
         pred_dist.copy_(dist)
-        d = torch.cat(
-            (
-                dist[None, 1:-1, 0:-2],
-                dist[None, 2:, 1:-1],
-                dist[None, 1:-1, 2:],
-                dist[None, 0:-2, 1:-1]
-            ),
-        0).min(dim = 0)[0] + 1
+        d = (
+            torch.cat(
+                (
+                    dist[None, 1:-1, 0:-2],
+                    dist[None, 2:, 1:-1],
+                    dist[None, 1:-1, 2:],
+                    dist[None, 0:-2, 1:-1],
+                ),
+                0,
+            ).min(dim=0)[0]
+            + 1
+        )
 
         dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
         dist = walls * max_length + (1 - walls) * dist
 
 
         dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
         dist = walls * max_length + (1 - walls) * dist
 
-        if dist.equal(pred_dist): return dist * (1 - walls)
+        if dist.equal(pred_dist):
+            return dist * (1 - walls)
+
 
 ######################################################################
 
 
 ######################################################################
 
+
 def compute_policy(walls, i, j):
     distance = compute_distance(walls, i, j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
 def compute_policy(walls, i, j):
     distance = compute_distance(walls, i, j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0,  :  , 1:  ] = distance[ :  ,  :-1]
-    value[1,  :  ,  :-1] = distance[ :  , 1:  ]
-    value[2, 1:  ,  :  ] = distance[ :-1,  :  ]
-    value[3,  :-1,  :  ] = distance[1:  ,  :  ]
+    value[0, :, 1:] = distance[:, :-1]
+    value[1, :, :-1] = distance[:, 1:]
+    value[2, 1:, :] = distance[:-1, :]
+    value[3, :-1, :] = distance[1:, :]
 
 
-    proba = (value.min(dim = 0)[0][None] == value).float()
-    proba = proba / proba.sum(dim = 0)[None]
+    proba = (value.min(dim=0)[0][None] == value).float()
+    proba = proba / proba.sum(dim=0)[None]
     proba = proba * (1 - walls)
 
     return proba
 
     proba = proba * (1 - walls)
 
     return proba
 
+
 ######################################################################
 
 ######################################################################
 
-def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
+
+def create_maze_data(nb, h=11, w=17, nb_walls=8, traj_length=50):
     input = torch.empty(nb, 10, h, w)
     targets = torch.empty(nb, 2, h, w)
 
     input = torch.empty(nb, 10, h, w)
     targets = torch.empty(nb, 2, h, w)
 
@@ -218,8 +245,8 @@ def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
     eta = ETA(nb)
 
     for n in range(nb):
     eta = ETA(nb)
 
     for n in range(nb):
-        if n%(max(10, nb//1000)) == 0:
-            log_string(f'{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}')
+        if n % (max(10, nb // 1000)) == 0:
+            log_string(f"{(100 * n)/nb:.02f}% ETA {eta.eta(n+1)}")
 
         walls = create_maze(h, w, nb_walls)
         trans = create_transitions(walls, l[n])
 
         walls = create_maze(h, w, nb_walls)
         trans = create_transitions(walls, l[n])
@@ -234,18 +261,21 @@ def create_maze_data(nb, h = 11, w = 17, nb_walls = 8, traj_length = 50):
 
     return input, targets
 
 
     return input, targets
 
+
 ######################################################################
 
 ######################################################################
 
-def save_image(name, input, targets, output = None):
+
+def save_image(name, input, targets, output=None):
     input, targets = input.cpu(), targets.cpu()
 
     weight = torch.tensor(
         [
     input, targets = input.cpu(), targets.cpu()
 
     weight = torch.tensor(
         [
-            [ 1.0, 0.0, 0.0 ],
-            [ 1.0, 1.0, 0.0 ],
-            [ 0.0, 1.0, 0.0 ],
-            [ 0.0, 0.0, 1.0 ],
-        ] ).t()[:, :, None, None]
+            [1.0, 0.0, 0.0],
+            [1.0, 1.0, 0.0],
+            [0.0, 1.0, 0.0],
+            [0.0, 0.0, 1.0],
+        ]
+    ).t()[:, :, None, None]
 
     # img_trans = F.conv2d(input[:, 0:5], weight)
     # img_trans = img_trans / img_trans.max()
 
     # img_trans = F.conv2d(input[:, 0:5], weight)
     # img_trans = img_trans / img_trans.max()
@@ -271,7 +301,7 @@ def save_image(name, input, targets, output = None):
         img_walls[:, None],
         # img_pi[:, None],
         img_dist[:, None],
         img_walls[:, None],
         # img_pi[:, None],
         img_dist[:, None],
-        )
+    )
 
     if output is not None:
         output = output.cpu()
 
     if output is not None:
         output = output.cpu()
@@ -299,24 +329,22 @@ def save_image(name, input, targets, output = None):
         img_all.size(4),
     )
 
         img_all.size(4),
     )
 
-    torchvision.utils.save_image(
-        img_all,
-        name,
-        padding = 1, pad_value = 0.5, nrow = len(img)
-    )
+    torchvision.utils.save_image(img_all, name, padding=1, pad_value=0.5, nrow=len(img))
+
+    log_string(f"Wrote {name}")
 
 
-    log_string(f'Wrote {name}')
 
 ######################################################################
 
 
 ######################################################################
 
+
 class Net(nn.Module):
     def __init__(self):
         super().__init__()
         nh = 128
 class Net(nn.Module):
     def __init__(self):
         super().__init__()
         nh = 128
-        self.conv1 = nn.Conv2d( 6, nh, kernel_size = 5, padding = 2)
-        self.conv2 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
-        self.conv3 = nn.Conv2d(nh, nh, kernel_size = 5, padding = 2)
-        self.conv4 = nn.Conv2d(nh,  2, kernel_size = 5, padding = 2)
+        self.conv1 = nn.Conv2d(6, nh, kernel_size=5, padding=2)
+        self.conv2 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
+        self.conv3 = nn.Conv2d(nh, nh, kernel_size=5, padding=2)
+        self.conv4 = nn.Conv2d(nh, 2, kernel_size=5, padding=2)
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
 
     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -325,21 +353,29 @@ class Net(nn.Module):
         x = self.conv4(x)
         return x
 
         x = self.conv4(x)
         return x
 
+
 ######################################################################
 
 ######################################################################
 
+
 class ResNetBlock(nn.Module):
     def __init__(self, nb_channels, kernel_size):
         super().__init__()
 
 class ResNetBlock(nn.Module):
     def __init__(self, nb_channels, kernel_size):
         super().__init__()
 
-        self.conv1 = nn.Conv2d(nb_channels, nb_channels,
-                               kernel_size = kernel_size,
-                               padding = (kernel_size - 1) // 2)
+        self.conv1 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
 
         self.bn1 = nn.BatchNorm2d(nb_channels)
 
 
         self.bn1 = nn.BatchNorm2d(nb_channels)
 
-        self.conv2 = nn.Conv2d(nb_channels, nb_channels,
-                               kernel_size = kernel_size,
-                               padding = (kernel_size - 1) // 2)
+        self.conv2 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
 
         self.bn2 = nn.BatchNorm2d(nb_channels)
 
 
         self.bn2 = nn.BatchNorm2d(nb_channels)
 
@@ -348,19 +384,22 @@ class ResNetBlock(nn.Module):
         y = F.relu(x + self.bn2(self.conv2(y)))
         return y
 
         y = F.relu(x + self.bn2(self.conv2(y)))
         return y
 
-class ResNet(nn.Module):
 
 
-    def __init__(self,
-                 in_channels, out_channels,
-                 nb_residual_blocks, nb_channels, kernel_size):
+class ResNet(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, nb_residual_blocks, nb_channels, kernel_size
+    ):
         super().__init__()
 
         self.pre_process = nn.Sequential(
         super().__init__()
 
         self.pre_process = nn.Sequential(
-            nn.Conv2d(in_channels, nb_channels,
-                      kernel_size = kernel_size,
-                      padding = (kernel_size - 1) // 2),
+            nn.Conv2d(
+                in_channels,
+                nb_channels,
+                kernel_size=kernel_size,
+                padding=(kernel_size - 1) // 2,
+            ),
             nn.BatchNorm2d(nb_channels),
             nn.BatchNorm2d(nb_channels),
-            nn.ReLU(inplace = True),
+            nn.ReLU(inplace=True),
         )
 
         blocks = []
         )
 
         blocks = []
@@ -369,7 +408,7 @@ class ResNet(nn.Module):
 
         self.resnet_blocks = nn.Sequential(*blocks)
 
 
         self.resnet_blocks = nn.Sequential(*blocks)
 
-        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size = 1)
+        self.post_process = nn.Conv2d(nb_channels, out_channels, kernel_size=1)
 
     def forward(self, x):
         x = self.pre_process(x)
 
     def forward(self, x):
         x = self.pre_process(x)
@@ -377,49 +416,53 @@ class ResNet(nn.Module):
         x = self.post_process(x)
         return x
 
         x = self.post_process(x)
         return x
 
+
 ######################################################################
 
 ######################################################################
 
-data_filename = 'path.dat'
+data_filename = "path.dat"
 
 try:
     input, targets = torch.load(data_filename)
 
 try:
     input, targets = torch.load(data_filename)
-    log_string('Data loaded.')
-    assert input.size(0) == args.nb_for_train + args.nb_for_test and \
-           input.size(1) == 10 and \
-           input.size(2) == args.world_height and \
-           input.size(3) == args.world_width and \
-           \
-           targets.size(0) == args.nb_for_train + args.nb_for_test and \
-           targets.size(1) == 2 and \
-           targets.size(2) == args.world_height and \
-           targets.size(3) == args.world_width
+    log_string("Data loaded.")
+    assert (
+        input.size(0) == args.nb_for_train + args.nb_for_test
+        and input.size(1) == 10
+        and input.size(2) == args.world_height
+        and input.size(3) == args.world_width
+        and targets.size(0) == args.nb_for_train + args.nb_for_test
+        and targets.size(1) == 2
+        and targets.size(2) == args.world_height
+        and targets.size(3) == args.world_width
+    )
 
 except FileNotFoundError:
 
 except FileNotFoundError:
-    log_string('Generating data.')
+    log_string("Generating data.")
 
     input, targets = create_maze_data(
 
     input, targets = create_maze_data(
-        nb = args.nb_for_train + args.nb_for_test,
-        h = args.world_height, w = args.world_width,
-        nb_walls = args.world_nb_walls,
-        traj_length = (100, 10000)
+        nb=args.nb_for_train + args.nb_for_test,
+        h=args.world_height,
+        w=args.world_width,
+        nb_walls=args.world_nb_walls,
+        traj_length=(100, 10000),
     )
 
     torch.save((input, targets), data_filename)
 
 except:
     )
 
     torch.save((input, targets), data_filename)
 
 except:
-    log_string('Error when loading data.')
+    log_string("Error when loading data.")
     exit(1)
 
 ######################################################################
 
 for n in vars(args):
     exit(1)
 
 ######################################################################
 
 for n in vars(args):
-    log_string(f'args.{n} {getattr(args, n)}')
+    log_string(f"args.{n} {getattr(args, n)}")
 
 model = ResNet(
 
 model = ResNet(
-    in_channels = 10, out_channels = 2,
-    nb_residual_blocks = args.nb_residual_blocks,
-    nb_channels = args.nb_channels,
-    kernel_size = args.kernel_size
+    in_channels=10,
+    out_channels=2,
+    nb_residual_blocks=args.nb_residual_blocks,
+    nb_channels=args.nb_channels,
+    kernel_size=args.kernel_size,
 )
 
 criterion = nn.MSELoss()
 )
 
 criterion = nn.MSELoss()
@@ -429,8 +472,8 @@ criterion.to(device)
 
 input, targets = input.to(device), targets.to(device)
 
 
 input, targets = input.to(device), targets.to(device)
 
-train_input, train_targets = input[:args.nb_for_train], targets[:args.nb_for_train]
-test_input, test_targets   = input[args.nb_for_train:], targets[args.nb_for_train:]
+train_input, train_targets = input[: args.nb_for_train], targets[: args.nb_for_train]
+test_input, test_targets = input[args.nb_for_train :], targets[args.nb_for_train :]
 
 mu, std = train_input.mean(), train_input.std()
 train_input.sub_(mu).div_(std)
 
 mu, std = train_input.mean(), train_input.std()
 train_input.sub_(mu).div_(std)
@@ -447,12 +490,13 @@ for e in range(args.nb_epochs):
     else:
         lr = 1e-3
 
     else:
         lr = 1e-3
 
-    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
     acc_train_loss = 0.0
 
 
     acc_train_loss = 0.0
 
-    for input, targets in zip(train_input.split(args.batch_size),
-                              train_targets.split(args.batch_size)):
+    for input, targets in zip(
+        train_input.split(args.batch_size), train_targets.split(args.batch_size)
+    ):
         output = model(input)
 
         loss = criterion(output, targets)
         output = model(input)
 
         loss = criterion(output, targets)
@@ -464,20 +508,31 @@ for e in range(args.nb_epochs):
 
     test_loss = 0.0
 
 
     test_loss = 0.0
 
-    for input, targets in zip(test_input.split(args.batch_size),
-                              test_targets.split(args.batch_size)):
+    for input, targets in zip(
+        test_input.split(args.batch_size), test_targets.split(args.batch_size)
+    ):
         output = model(input)
         loss = criterion(output, targets)
         test_loss += loss.item()
 
     log_string(
         output = model(input)
         loss = criterion(output, targets)
         test_loss += loss.item()
 
     log_string(
-        f'{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}'
+        f"{e} acc_train_loss {acc_train_loss / (args.nb_for_train / args.batch_size)} test_loss {test_loss / (args.nb_for_test / args.batch_size)} ETA {eta.eta(e+1)}"
     )
 
     # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
     # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
 
     )
 
     # save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8]))
     # save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8]))
 
-    save_image(f'train_{e:04d}.png', train_input[:8], train_targets[:8], model(train_input[:8])[:, 0:2])
-    save_image(f'test_{e:04d}.png', test_input[:8], test_targets[:8], model(test_input[:8])[:, 0:2])
+    save_image(
+        f"train_{e:04d}.png",
+        train_input[:8],
+        train_targets[:8],
+        model(train_input[:8])[:, 0:2],
+    )
+    save_image(
+        f"test_{e:04d}.png",
+        test_input[:8],
+        test_targets[:8],
+        model(test_input[:8])[:, 0:2],
+    )
 
 ######################################################################
 
 ######################################################################