Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2023 14:50:51 +0000 (16:50 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 16 Jun 2023 14:50:51 +0000 (16:50 +0200)
warp.py

diff --git a/warp.py b/warp.py
index 96dfa11..6212887 100755 (executable)
--- a/warp.py
+++ b/warp.py
@@ -49,15 +49,6 @@ train_input[:, 1] -= 0.15 * (train_targets * 2 - 1)
 train_input *= 1.2
 
 
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, x):
-        return 0.5 * x + 0.5 * self.f(x)
-
-
 model = nn.Sequential(
     nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
     nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
@@ -110,12 +101,14 @@ for k in range(nb_epochs):
 
 ######################################################################
 
-sg=25
+sg = 25
 
 input, targets = train_input, train_targets
 
-grid = torch.linspace(-1.2,1.2,sg)
-grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2)
+grid = torch.linspace(-1.2, 1.2, sg)
+grid = torch.cat(
+    (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
+).reshape(-1, 2)
 
 for l, m in enumerate(model):
     with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
@@ -133,22 +126,22 @@ x y label
             f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
         f.write("};\n")
 
-        g = grid.reshape(sg,sg,-1)
+        g = grid.reshape(sg, sg, -1)
         for i in range(g.size(0)):
             for j in range(g.size(1)):
                 if j == 0:
-                    pre="\\draw[black!25,very thin] "
+                    pre = "\\draw[black!25,very thin] "
                 else:
-                    pre="--"
+                    pre = "--"
                 f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
             f.write(";\n")
 
         for j in range(g.size(1)):
             for i in range(g.size(0)):
                 if i == 0:
-                    pre="\\draw[black!25,very thin] "
+                    pre = "\\draw[black!25,very thin] "
                 else:
-                    pre="--"
+                    pre = "--"
                 f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
             f.write(";\n")