Update.
[pytorch.git] / redshift.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 torch.set_default_dtype(torch.float64)
11
12 res = 256
13 nh = 100
14
15 input = torch.cat(
16     [
17         torch.linspace(-1, 1, res)[None, :, None].expand(res, res, 1),
18         torch.linspace(-1, 1, res)[:, None, None].expand(res, res, 1),
19     ],
20     dim=-1,
21 ).reshape(-1, 2)
22
23
24 class Angles(nn.Module):
25     def forward(self, x):
26         return x.clamp(min=-0.5, max=0.5)
27
28
29 for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
30     for s in [1.0, 10.0]:
31         layers = [nn.Linear(2, nh), activation()]
32         nb_hidden = 4
33         for k in range(nb_hidden):
34             layers += [nn.Linear(nh, nh), activation()]
35         layers += [nn.Linear(nh, 2)]
36         model = nn.Sequential(*layers)
37
38         with torch.no_grad():
39             for p in model.parameters():
40                 p *= s
41
42         output = model(input)
43
44         img = (output[:, 1] - output[:, 0]).reshape(1, 1, res, res)
45
46         img = (img - img.mean()) / (1 * img.std())
47
48         img = img.clamp(min=-1, max=1)
49
50         img = torch.cat(
51             [
52                 (1 + img).clamp(max=1),
53                 (1 - img.abs()).clamp(min=0),
54                 (1 - img).clamp(max=1),
55             ],
56             dim=1,
57         )
58
59         name_activation = {
60             nn.ReLU: "relu",
61             nn.Tanh: "tanh",
62             nn.Softplus: "softplus",
63             Angles: "angles",
64         }[activation]
65
66         torchvision.utils.save_image(img, f"result-{name_activation}-{s}.png")