From c7b7f15232cb5507793ef9fbe5bbb3084ea88346 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 16 Jun 2023 16:50:51 +0200 Subject: [PATCH] Update. --- warp.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/warp.py b/warp.py index 96dfa11..6212887 100755 --- a/warp.py +++ b/warp.py @@ -49,15 +49,6 @@ train_input[:, 1] -= 0.15 * (train_targets * 2 - 1) train_input *= 1.2 -class WithResidual(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, x): - return 0.5 * x + 0.5 * self.f(x) - - model = nn.Sequential( nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), @@ -110,12 +101,14 @@ for k in range(nb_epochs): ###################################################################### -sg=25 +sg = 25 input, targets = train_input, train_targets -grid = torch.linspace(-1.2,1.2,sg) -grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2) +grid = torch.linspace(-1.2, 1.2, sg) +grid = torch.cat( + (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1 +).reshape(-1, 2) for l, m in enumerate(model): with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f: @@ -133,22 +126,22 @@ x y label f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n") f.write("};\n") - g = grid.reshape(sg,sg,-1) + g = grid.reshape(sg, sg, -1) for i in range(g.size(0)): for j in range(g.size(1)): if j == 0: - pre="\\draw[black!25,very thin] " + pre = "\\draw[black!25,very thin] " else: - pre="--" + pre = "--" f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})") f.write(";\n") for j in range(g.size(1)): for i in range(g.size(0)): if i == 0: - pre="\\draw[black!25,very thin] " + pre = "\\draw[black!25,very thin] " else: - pre="--" + pre = "--" f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})") f.write(";\n") -- 2.39.5