train_input *= 1.2
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, x):
- return 0.5 * x + 0.5 * self.f(x)
-
-
model = nn.Sequential(
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
######################################################################
-sg=25
+sg = 25
input, targets = train_input, train_targets
-grid = torch.linspace(-1.2,1.2,sg)
-grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2)
+grid = torch.linspace(-1.2, 1.2, sg)
+grid = torch.cat(
+ (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
+).reshape(-1, 2)
for l, m in enumerate(model):
with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
f.write("};\n")
- g = grid.reshape(sg,sg,-1)
+ g = grid.reshape(sg, sg, -1)
for i in range(g.size(0)):
for j in range(g.size(1)):
if j == 0:
- pre="\\draw[black!25,very thin] "
+ pre = "\\draw[black!25,very thin] "
else:
- pre="--"
+ pre = "--"
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")
for j in range(g.size(1)):
for i in range(g.size(0)):
if i == 0:
- pre="\\draw[black!25,very thin] "
+ pre = "\\draw[black!25,very thin] "
else:
- pre="--"
+ pre = "--"
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")