# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
# Written by Francois Fleuret <francois@fleuret.org>
9 import matplotlib.pyplot as plt
14 ######################################################################
# data_rectangle(nb): sample `nb` 2-D points from a rotated rectangle.
#
# `x` is uniform in [-0.5, 0.5) and `y` is uniform in [-1, 1); they are
# stacked as columns (y, x) and right-multiplied by the 2x2 rotation matrix
# for the angle `alpha`.  The visible return value is the pair
# (points, "rectangle").
#
# NOTE(review): `alpha` is not assigned anywhere in the visible lines, and
# the `torch.tensor(` call opened below is never closed before the `return`.
# The numbers embedded at the start of each line skip 21 and 24, so source
# lines appear to be missing here -- confirm against the original file.
17 def data_rectangle(nb):
18 x = torch.rand(nb, 1) - 0.5
19 y = torch.rand(nb, 1) * 2 - 1
# Column 0 holds y, column 1 holds x.
20 data = torch.cat((y, x), 1)
22 data = data @ torch.tensor(
23 [[math.cos(alpha), math.sin(alpha)], [-math.sin(alpha), math.cos(alpha)]]
25 return data, "rectangle"
# NOTE(review): the `def` line that owns this body is not visible; judging by
# the driver loop at the end of the file it is presumably `data_zigzag(nb)`
# -- confirm.  `y` is used below but never assigned in the visible lines, and
# no `return` is visible either (the embedded line numbers skip 28, 30, 32
# and 35).
#
# `a` is a column of `nb` samples drawn uniformly from [0, 1).
29 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# Cosine of amplitude 0.4 covering 2.5 periods as `a` sweeps its range.
31 x = 0.4 * ((a - 0.5) * 5 * math.pi).cos()
# Column 0 holds y, column 1 holds x.
33 data = torch.cat((y, x), 1)
# Maps each row (y, x) to (y + x, x - y): a 45-degree rotation scaled by
# sqrt(2).
34 data = data @ torch.tensor([[1.0, -1.0], [1.0, 1.0]])
# NOTE(review): the `def` line that owns this body is not visible; judging by
# the driver loop at the end of the file it is presumably `data_spiral(nb)`
# -- confirm.  No `return` is visible for it either.
#
# `a` is a column of `nb` samples drawn uniformly from [0, 1).
39 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# Polar form: the angle a * 2.25 * pi sweeps 0 .. 2.25*pi while the radius
# a * 0.8 + 0.5 grows from 0.5 to 1.3, so the points trace a spiral arc.
40 x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
41 y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
# Column 0 holds y, column 1 holds x.
42 data = torch.cat((y, x), 1)
# NOTE(review): the `def` line that owns this body is not visible; judging by
# the driver loop at the end of the file it is presumably `data_penta(nb)`
# -- confirm.  `x` and `y` are used below but never assigned in the visible
# lines (the embedded line numbers skip 48 and 49), and no `return` is
# visible.
#
# `a` is a column of `nb` angles, each one of the five values k * 2*pi/5 for
# k in {0, 1, 2, 3, 4}.
47 a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
# Column 0 holds y, column 1 holds x.
50 data = torch.cat((y, x), 1)
# Jitter every coordinate with zero-mean Gaussian noise of std 0.05.
51 data = data + data.new(data.size()).normal_(0, 0.05)
55 ######################################################################
def train_model(data, *, batch_size=100, nb_epochs=1000):
    """Train a small MLP as a denoising auto-encoder on 2-D points.

    The network (2 -> 100 -> ReLU -> 2) receives each mini-batch corrupted
    with zero-mean Gaussian noise of standard deviation 0.1 and is trained
    with an MSE loss to output the clean points.

    Args:
        data: float tensor whose rows are 2-D points (it is fed to
            ``nn.Linear(2, 100)`` after being split along dimension 0).
        batch_size: number of rows per mini-batch.
        nb_epochs: number of passes over ``data``.

    Returns:
        The trained ``nn.Sequential`` model.
    """
    model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for e in range(nb_epochs):
        # Reset per epoch so the printed figure is that epoch's summed loss.
        acc_loss = 0.0

        for batch in data.split(batch_size):
            noise = torch.randn_like(batch) * 0.1
            output = model(batch + noise)
            # The target is the clean batch: the model learns to undo the noise.
            loss = criterion(output, batch)
            acc_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Progress report every 100 epochs.
        if (e + 1) % 100 == 0:
            print(e + 1, acc_loss)

    return model
81 ######################################################################
# save_image(data_name, model, data): plot the field learned by `model`
# around the points in `data` and save it as
# "denoising_field_<data_name>.pdf".
#
# NOTE(review): this block is incomplete in the visible lines.  `fig` is used
# but never created, and the `ax.quiver(` and `ax.scatter(` calls are opened
# but never closed (the embedded line numbers skip 96-97, 108-116 and
# 119-120).  Confirm against the original file before relying on the notes
# below.
84 def save_image(data_name, model, data):
# 30 evenly spaced coordinates in [-1.5, 1.5], expanded into a 30 x 30 grid:
# x varies along dimension 1, y along dimension 0.  `grid` is the flattened
# (900, 2) list of points with columns (y, x).
85 a = torch.linspace(-1.5, 1.5, 30)
86 x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
87 y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
88 grid = torch.cat((y, x), 2).view(-1, 2)
90 # Take the origins of the arrows on the part of the grid closer than
91 # sqrt(0.1) to the data points
# `dist` is the *squared* Euclidean distance from each grid point to its
# nearest data point, hence the 0.1 threshold for a sqrt(0.1) radius.
92 dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
93 origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
# Displacement from each origin to the model's output for it; detached so no
# autograd graph is carried into the plotting code.  Presumably passed to
# `quiver` in lines not visible here -- confirm.
95 field = model(origins).detach() - origins
98 ax = fig.add_subplot(1, 1, 1)
# Axis limits leave a small margin around the [-1.5, 1.5] grid.
101 ax.set_xlim(-1.6, 1.6)
102 ax.set_ylim(-1.6, 1.6)
# Arrows anchored at `origins`: column 0 goes on the horizontal axis and
# column 1 on the vertical axis.
105 plot_field = ax.quiver(
106 origins[:, 0].numpy(),
107 origins[:, 1].numpy(),
# The data points themselves, drawn as small blue dots with the same column
# convention.
117 plot_data = ax.scatter(
118 data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue"
121 filename = f"denoising_field_{data_name}.pdf"
122 print(f"Saving {filename}")
123 fig.savefig(filename, bbox_inches="tight")
126 ######################################################################
# Driver: for each synthetic data set, draw 1000 points, centre them, fit the
# denoising model and save the corresponding field plot.
for data_source in [data_rectangle, data_zigzag, data_spiral, data_penta]:
    data, data_name = data_source(1000)
    # Subtract the per-coordinate mean so every data set is centred on the
    # origin.
    data = data - data.mean(0)
    model = train_model(data)
    save_image(data_name, model, data)