Update.
[pytorch.git] / redshift.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 torch.set_default_dtype(torch.float64)
11
12 nb_hidden = 5
13 hidden_dim = 100
14
15 res = 256
16
17 input = torch.cat(
18     [
19         torch.linspace(-1, 1, res)[None, :, None].expand(res, res, 1),
20         torch.linspace(-1, 1, res)[:, None, None].expand(res, res, 1),
21     ],
22     dim=-1,
23 ).reshape(-1, 2)
24
25
26 class Angles(nn.Module):
27     def forward(self, x):
28         return x.clamp(min=-0.5, max=0.5)
29
30
31 for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
32     for s in [1.0, 10.0]:
33         layers = [nn.Linear(2, hidden_dim), activation()]
34         for k in range(nb_hidden - 1):
35             layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
36         layers += [nn.Linear(hidden_dim, 2)]
37         model = nn.Sequential(*layers)
38
39         with torch.no_grad():
40             for p in model.parameters():
41                 p *= s
42
43         output = model(input)
44
45         img = (output[:, 1] - output[:, 0]).reshape(1, 1, res, res)
46
47         img = (img - img.mean()) / (1 * img.std())
48
49         img = img.clamp(min=-1, max=1)
50
51         img = torch.cat(
52             [
53                 (1 + img).clamp(max=1),
54                 (1 - img.abs()).clamp(min=0),
55                 (1 - img).clamp(max=1),
56             ],
57             dim=1,
58         )
59
60         name_activation = {
61             nn.ReLU: "relu",
62             nn.Tanh: "tanh",
63             nn.Softplus: "softplus",
64             Angles: "angles",
65         }[activation]
66
67         torchvision.utils.save_image(img, f"result-{name_activation}-{s}.png")