b738b521ebe1008958f4e44976e938d4af3c0db4
[pytorch.git] / hallu.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # ImageMagick's montage to make the mosaic
9 #
10 # montage hallu-*.png -tile 5x6 -geometry +1+1 result.png
11
12 import PIL, torch, torchvision
13 from torch.nn import functional as F
14
15 class MultiScaleEdgeEnergy(torch.nn.Module):
16     def __init__(self):
17         super(MultiScaleEdgeEnergy, self).__init__()
18         k = torch.exp(- torch.tensor([[-2., -1., 0., 1., 2.]])**2 / 2)
19         k = (k.t() @ k).view(1, 1, 5, 5)
20         self.gaussian_5x5 = torch.nn.Parameter(k / k.sum()).requires_grad_(False)
21
22     def forward(self, x):
23         u = x.view(-1, 1, x.size(2), x.size(3))
24         result = 0.0
25         while min(u.size(2), u.size(3)) > 5:
26             blurry  = F.conv2d(u, self.gaussian_5x5, padding = 2)
27             result += (u - blurry).view(u.size(0), -1).pow(2).sum(1)
28             u = F.avg_pool2d(u, kernel_size = 2, padding = 1)
29         return result.view(x.size(0), -1).sum(1)
30
31 img = torchvision.transforms.ToTensor()(PIL.Image.open('blacklab.jpg'))
32 img = img.view((1,) + img.size())
33 ref_input = 0.5 + 0.5 * (img - img.mean()) / img.std()
34
35 mse_loss = torch.nn.MSELoss()
36 edge_energy = MultiScaleEdgeEnergy()
37
38 layers = torchvision.models.vgg16(pretrained = True).features
39 layers.eval()
40
41 if torch.cuda.is_available():
42     edge_energy.cuda()
43     ref_input = ref_input.cuda()
44     layers.cuda()
45
46 for l in [ 5, 7, 12, 17, 21, 28 ]:
47     model = torch.nn.Sequential(layers[:l])
48     ref_output = model(ref_input).detach()
49
50     for n in range(5):
51         input = torch.empty_like(ref_input).uniform_(-0.01, 0.01).requires_grad_()
52         optimizer = torch.optim.Adam( [ input ], lr = 1e-2)
53         for k in range(1000):
54             output = model(input)
55             loss = mse_loss(output, ref_output) + 1e-3 * edge_energy(input)
56             optimizer.zero_grad()
57             loss.backward()
58             optimizer.step()
59
60         img = 0.5 + 0.2 * (input - input.mean()) / input.std()
61         result_name = 'hallu-l%02d-n%02d.png' % (l, n)
62         torchvision.utils.save_image(img, result_name)
63
64         print('Wrote ' + result_name)