projects
/
tex.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
db5fb7d
)
Update.
author
François Fleuret
<francois@fleuret.org>
Fri, 16 Jun 2023 14:50:51 +0000
(16:50 +0200)
committer
François Fleuret
<francois@fleuret.org>
Fri, 16 Jun 2023 14:50:51 +0000
(16:50 +0200)
warp.py
patch
|
blob
|
history
diff --git
a/warp.py
b/warp.py
index
96dfa11
..
6212887
100755
(executable)
--- a/
warp.py
+++ b/
warp.py
@@
-49,15
+49,6
@@
train_input[:, 1] -= 0.15 * (train_targets * 2 - 1)
train_input *= 1.2
train_input *= 1.2
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, x):
- return 0.5 * x + 0.5 * self.f(x)
-
-
model = nn.Sequential(
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
model = nn.Sequential(
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
@@
-110,12
+101,14
@@
for k in range(nb_epochs):
######################################################################
######################################################################
-sg
=
25
+sg
=
25
input, targets = train_input, train_targets
input, targets = train_input, train_targets
-grid = torch.linspace(-1.2,1.2,sg)
-grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2)
+grid = torch.linspace(-1.2, 1.2, sg)
+grid = torch.cat(
+ (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
+).reshape(-1, 2)
for l, m in enumerate(model):
with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
for l, m in enumerate(model):
with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
@@
-133,22
+126,22
@@
x y label
f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
f.write("};\n")
f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
f.write("};\n")
- g = grid.reshape(sg,
sg,
-1)
+ g = grid.reshape(sg,
sg,
-1)
for i in range(g.size(0)):
for j in range(g.size(1)):
if j == 0:
for i in range(g.size(0)):
for j in range(g.size(1)):
if j == 0:
- pre
=
"\\draw[black!25,very thin] "
+ pre
=
"\\draw[black!25,very thin] "
else:
else:
- pre
=
"--"
+ pre
=
"--"
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")
for j in range(g.size(1)):
for i in range(g.size(0)):
if i == 0:
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")
for j in range(g.size(1)):
for i in range(g.size(0)):
if i == 0:
- pre
=
"\\draw[black!25,very thin] "
+ pre
=
"\\draw[black!25,very thin] "
else:
else:
- pre
=
"--"
+ pre
=
"--"
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")
f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
f.write(";\n")