projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
e111e52
)
Update.
author
Francois Fleuret
<francois@fleuret.org>
Wed, 5 Sep 2018 07:21:01 +0000
(09:21 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Wed, 5 Sep 2018 07:21:01 +0000
(09:21 +0200)
hallu.py
patch
|
blob
|
history
diff --git
a/hallu.py
b/hallu.py
index
6b0b303
..
7da66a6
100755
(executable)
--- a/
hallu.py
+++ b/
hallu.py
@@
-12,7
+12,7
@@
class MultiScaleEdgeEnergy(torch.nn.Module):
super(MultiScaleEdgeEnergy, self).__init__()
k = torch.exp(- torch.tensor([[-2., -1., 0., 1., 2.]])**2 / 2)
k = (k.t() @ k).view(1, 1, 5, 5)
super(MultiScaleEdgeEnergy, self).__init__()
k = torch.exp(- torch.tensor([[-2., -1., 0., 1., 2.]])**2 / 2)
k = (k.t() @ k).view(1, 1, 5, 5)
- self.
register_buffer('gaussian_5x5', k / k.sum()
)
+ self.
gaussian_5x5 = torch.nn.Parameter(k / k.sum()).requires_grad_(False
)
def forward(self, x):
u = x.view(-1, 1, x.size(2), x.size(3))
def forward(self, x):
u = x.view(-1, 1, x.size(2), x.size(3))
@@
-43,7
+43,7
@@
for l in [ 5, 7, 12, 17, 21, 28 ]:
ref_output = model(ref_input).detach()
for n in range(5):
ref_output = model(ref_input).detach()
for n in range(5):
- input =
ref_input.new_empty(ref_input.size()
).uniform_(-0.01, 0.01).requires_grad_()
+ input =
torch.empty_like(ref_input
).uniform_(-0.01, 0.01).requires_grad_()
optimizer = torch.optim.Adam( [ input ], lr = 1e-2)
for k in range(1000):
output = model(input)
optimizer = torch.optim.Adam( [ input ], lr = 1e-2)
for k in range(1000):
output = model(input)