Minor syntactic change.
[pytorch] / denoising-ae-field.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import math
9 import matplotlib.pyplot as plt
10
11 import torch
12 from torch import nn
13
14 ######################################################################
15
16 def data_rectangle(nb):
17     x = torch.rand(nb, 1) - 0.5
18     y = torch.rand(nb, 1) * 2 - 1
19     data = torch.cat((y, x), 1)
20     alpha = math.pi / 8
21     data = data @ torch.tensor(
22         [
23             [ math.cos(alpha), math.sin(alpha)],
24             [-math.sin(alpha), math.cos(alpha)]
25         ]
26     )
27     return data, 'rectangle'
28
29 def data_zigzag(nb):
30     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
31     # zigzag
32     x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
33     y = a * 2.5 - 1.25
34     data = torch.cat((y, x), 1)
35     data = data @ torch.tensor([[1., -1.], [1., 1.]])
36     return data, 'zigzag'
37
38 def data_spiral(nb):
39     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
40     x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
41     y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
42     data = torch.cat((y, x), 1)
43     return data, 'spiral'
44
45 def data_penta(nb):
46     a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
47     x = a.cos()
48     y = a.sin()
49     data = torch.cat((y, x), 1)
50     data = data + data.new(data.size()).normal_(0, 0.05)
51     return data, 'penta'
52
53 ######################################################################
54
55 def train_model(data):
56     model = nn.Sequential(
57         nn.Linear(2, 100),
58         nn.ReLU(),
59         nn.Linear(100, 2)
60     )
61
62     batch_size, nb_epochs = 100, 1000
63     optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
64     criterion = nn.MSELoss()
65
66     for e in range(nb_epochs):
67         acc_loss = 0
68         for input in data.split(batch_size):
69             noise = input.new(input.size()).normal_(0, 0.1)
70             output = model(input + noise)
71             loss = criterion(output, input)
72             acc_loss += loss.item()
73             optimizer.zero_grad()
74             loss.backward()
75             optimizer.step()
76         if (e+1)%100 == 0: print(e+1, acc_loss)
77
78     return model
79
80 ######################################################################
81
82 def save_image(data_name, model, data):
83     a = torch.linspace(-1.5, 1.5, 30)
84     x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
85     y = a.view(-1,  1, 1).expand(a.size(0), a.size(0), 1)
86     grid = torch.cat((y, x), 2).view(-1, 2)
87
88     # Take the origins of the arrows on the part of the grid closer than
89     # sqrt(0.1) to the data points
90     dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
91     origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
92
93     field = model(origins).detach() - origins
94
95     fig = plt.figure()
96     ax = fig.add_subplot(1, 1, 1)
97
98     ax.axis('off')
99     ax.set_xlim(-1.6, 1.6)
100     ax.set_ylim(-1.6, 1.6)
101     ax.set_aspect(1)
102
103     plot_field = ax.quiver(
104         origins[:, 0].numpy(), origins[:, 1].numpy(),
105         field[:, 0].numpy(), field[:, 1].numpy(),
106         units = 'xy', scale = 1,
107         width = 3e-3, headwidth = 25, headlength = 25
108     )
109
110     plot_data = ax.scatter(
111         data[:, 0].numpy(), data[:, 1].numpy(),
112         s = 1, color = 'tab:blue'
113     )
114
115     filename = f'denoising_field_{data_name}.pdf'
116     print(f'Saving {filename}')
117     fig.savefig(filename, bbox_inches='tight')
118
119 ######################################################################
120
121 for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]:
122     data, data_name = data_source(1000)
123     data = data - data.mean(0)
124     model = train_model(data)
125     save_image(data_name, model, data)