3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
9 import matplotlib.pyplot as plt
14 ######################################################################
# NOTE(review): fragment of a 2-D point-cloud generator, presumably the body of
# `data_zigzag` (first entry of the list at the bottom of the file). Its `def`
# line, the assignment of `y`, and its return statement are not visible here.
# `a`: an (nb, 1) column of samples drawn uniformly from [0, 1).
17 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# `x` = 0.4 * cos((a - 0.5) * 5*pi): a cosine wave of amplitude 0.4 along `a`.
19 x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
# NOTE(review): `y` is used here but never assigned in the visible lines —
# presumably an (nb, 1) column derived from `a`; confirm.
21 data = torch.cat((y, x), 1)
# Right-multiply each row [y, x] by [[1, -1], [1, 1]] (= sqrt(2) times a
# 45-degree rotation), giving rows [y + x, x - y].
22 data = data @ torch.tensor([[1., -1.], [1., 1.]])
# NOTE(review): fragment of a 2-D point-cloud generator, presumably the body of
# `data_spiral`. Its `def` line and return statement are not visible here.
# `a`: an (nb, 1) column of samples drawn uniformly from [0, 1).
26 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# Polar point with angle a * 2.25*pi (1.125 turns over the range of `a`) and
# radius a * 0.8 + 0.5 growing linearly from 0.5 to 1.3: a spiral arm.
27 x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
28 y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
# Stack as columns [y, x] -> one 2-D point per row.
29 data = torch.cat((y, x), 1)
# NOTE(review): fragment of a 2-D point-cloud generator, presumably the body of
# `data_penta`. Its `def` line, the assignments of `x` and `y`, and its return
# statement are not visible here.
# `a`: per sample, one of the 5 equally spaced angles k * 2*pi/5 (k = 0..4),
# shaped as an (nb, 1) column.
33 a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
# NOTE(review): `x` and `y` are used here but not assigned in the visible lines.
36 data = torch.cat((y, x), 1)
# Add i.i.d. Gaussian noise (mean 0, std 0.05) to every coordinate.
# `data.new(data.size())` allocates an uninitialised tensor with the same
# dtype/device, which `normal_` then fills. NOTE(review): the modern equivalent
# is `torch.randn_like(data) * 0.05`.
37 data = data + data.new(data.size()).normal_(0, 0.05)
40 ######################################################################
# Train a small denoising network on `data`: every batch is corrupted with
# Gaussian noise (std 0.1), and the model is fit under an MSE loss to map the
# noisy points back onto the clean ones.
#
# NOTE(review): several lines of this function are not visible in this excerpt:
# the layers passed to nn.Sequential, the initialisation of `acc_loss`, any
# `optimizer.zero_grad()` / `loss.backward()` / `optimizer.step()` calls, and a
# return statement. The caller at the bottom of the file uses the result as the
# trained model, so a `return model` is presumably among the missing lines.
42 def train_model(data):
43 model = nn.Sequential(
49 batch_size, nb_epochs = 100, 1000
50 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
51 criterion = nn.MSELoss()
53 for e in range(nb_epochs):
# `data.split(batch_size)` yields consecutive chunks of 100 rows along dim 0, in
# the stored order (no shuffling is visible); the last chunk may be smaller.
55 for input in data.split(batch_size):
# Fresh noise every step, with the same shape, dtype and device as the batch.
56 noise = input.new(input.size()).normal_(0, 0.1)
57 output = model(input + noise)
# The target is the clean batch, so the network learns to undo the noise.
58 loss = criterion(output, input)
# NOTE(review): `acc_loss` is accumulated here, but its reset is not visible.
# Confirm it is set to 0 at the start of each epoch.
59 acc_loss += loss.item()
# Every 100 epochs, report the epoch number and the summed per-batch losses.
63 if (e+1)%100 == 0: print(e+1, acc_loss)
67 ######################################################################
# Draw the denoising vector field of `model` around the point cloud `data`, with
# the points overlaid, and save it as `denoising_field_{data_name}.pdf` in the
# current working directory.
#
# NOTE(review): some lines of this function are not visible in this excerpt.
# These include the creation of `fig` (used below but never assigned here) and
# the closing parentheses of the quiver/scatter calls.
69 def save_image(data_name, model, data):
# Regular 30 x 30 grid over [-1.5, 1.5]^2: x varies along dim 1, y along dim 0.
# After concatenation each row is [y, x], matching the column order used by the
# data generators, and the grid is flattened to (900, 2).
70 a = torch.linspace(-1.5, 1.5, 30)
71 x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
72 y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
73 grid = torch.cat((y, x), 2).view(-1, 2)
75 # Take the origins of the arrows on the part of the grid closer than
76 # sqrt(0.1) to the data points
# `dist[i]`: squared distance from grid point i to its nearest data point,
# computed through an (n_grid, n_data, 2) broadcast. Memory therefore scales
# with n_grid * n_data.
77 dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
# Indices of the grid points within range, then those rows of `grid`.
# This is equivalent to `grid[dist < 0.1]`.
78 origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
# Arrow vectors: the displacement the model applies to each origin. detach()
# drops autograd history so that `.numpy()` can be called on `field` below.
80 field = model(origins).detach() - origins
83 ax = fig.add_subplot(1, 1, 1)
# View window slightly larger than the [-1.5, 1.5] sampling grid.
86 ax.set_xlim(-1.6, 1.6)
87 ax.set_ylim(-1.6, 1.6)
# One arrow per origin, pointing along `field`.
# NOTE(review): matplotlib documents `angles='xy', scale_units='xy', scale=1` as
# the way to draw arrows at their data-coordinate length. Confirm that
# `units='xy', scale=1` alone gives the intended scaling.
90 plot_field = ax.quiver(
91 origins[:, 0].numpy(), origins[:, 1].numpy(),
92 field[:, 0].numpy(), field[:, 1].numpy(),
93 units = 'xy', scale = 1,
94 width = 3e-3, headwidth = 25, headlength = 25
# Overlay the data points as tiny blue markers.
97 plot_data = ax.scatter(
98 data[:, 0].numpy(), data[:, 1].numpy(),
99 s = 1, color = 'tab:blue'
102 filename = f'denoising_field_{data_name}.pdf'
103 print(f'Saving {filename}')
# bbox_inches='tight' trims surrounding whitespace from the saved figure.
# NOTE(review): no `plt.close(fig)` is visible in this excerpt. If it is absent,
# each call leaves its figure open.
104 fig.savefig(filename, bbox_inches='tight')
106 ######################################################################
# Driver: for every synthetic data set, draw 1000 samples, shift the cloud so
# its per-coordinate mean is zero, fit a denoiser on it, and save the plot of
# the resulting field.
# Each generator returns a (points, name) pair; the name is used for the file.
for data_source in (data_zigzag, data_spiral, data_penta):
    data, data_name = data_source(1000)
    data = data.sub(data.mean(dim=0))
    model = train_model(data=data)
    save_image(data_name=data_name, model=model, data=data)