5 import torch, torchvision
8 from torch.nn import functional as F
# Denoising network. It is trained below to map a noisy 2-D point
# (clean point + Gaussian noise) back to the clean point under an MSE loss,
# and is later evaluated on (N, 2) grid points to draw the denoising field.
# NOTE(review): the layer list and closing parenthesis are not part of this
# excerpt; `nn` is presumably imported on a line not shown here — confirm.
10 model = nn.Sequential(
16 ######################################################################
# Fragment of the zig-zag dataset generator. The enclosing `def`, the line
# that assigns `y`, and the return are not part of this excerpt.
# NOTE(review): `math` must be imported on a line not shown here — confirm.
# `a`: nb curve parameters drawn uniformly from [0, 1), as a (nb, 1) column.
19 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# Cosine of amplitude 0.4; (a - 0.5) * 5 * pi spans 5*pi, i.e. 2.5 periods,
# which produces the zig-zag oscillation along the curve.
21 x = 0.4 * ((a-0.5) * 5 * math.pi).cos()
# Stack the coordinates column-wise, y in column 0 and x in column 1
# (presumably (nb, 2) if `y` is also (nb, 1) — `y` is defined off-excerpt).
23 data = torch.cat((y, x), 1)
# [[1, -1], [1, 1]] equals sqrt(2) times a 45-degree rotation matrix, so this
# tilts the zig-zag diagonally and scales it by sqrt(2).
24 data = data @ torch.tensor([[1., -1.], [1., 1.]])
# Fragment of the spiral dataset generator. The enclosing `def` and the
# return are not part of this excerpt.
# `a`: nb curve parameters drawn uniformly from [0, 1), as a (nb, 1) column.
28 a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# Polar curve: angle a * 2.25 * pi (up to 1.125 turns) and radius
# a * 0.8 + 0.5 (from 0.5 up to 1.3), growing linearly with the angle —
# an Archimedean spiral.
29 x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
30 y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
# (nb, 2) points, y in column 0 and x in column 1.
31 data = torch.cat((y, x), 1)
34 ######################################################################
# Training set: 1000 zig-zag points (data_spiral(1000) is the alternative
# dataset). The cloud is then centered on the origin by subtracting the
# per-coordinate mean over all samples.
data = data_zigzag(1000)
data = data.sub(data.mean(dim=0))
# Training setup: mini-batches of 100 points, 1000 epochs, Adam (lr 1e-3),
# mean-squared-error loss.
41 batch_size, nb_epochs = 100, 1000
42 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
43 criterion = nn.MSELoss()
45 for e in range(nb_epochs):
# NOTE(review): resetting `acc_loss` each epoch and the optimizer update
# (zero_grad / backward / step) happen on lines not shown in this excerpt —
# confirm. No shuffling is visible: batches come from data.split in the same
# fixed order every epoch.
47 for input in data.split(batch_size):
# Denoising objective: corrupt the batch with N(0, 0.1^2) noise and train the
# model to recover the clean points.
# NOTE(review): `input` shadows the builtin, and `Tensor.new` is a legacy
# constructor; `torch.randn_like(input) * 0.1` is the modern equivalent.
48 noise = input.new(input.size()).normal_(0, 0.1)
49 output = model(input + noise)
50 loss = criterion(output, input)
# Sum of per-batch mean losses over the epoch.
51 acc_loss += loss.item()
# Report the accumulated loss every 10 epochs.
55 if (e+1)%10 == 0: print(e+1, acc_loss)
57 ######################################################################
# Regular 30 x 30 grid over [-1.5, 1.5]^2, flattened to shape (900, 2) with
# the row coordinate (y) in column 0 and the column coordinate (x) in
# column 1, matching the (y, x) column order used to build `data`.
a = torch.linspace(-1.5, 1.5, 30)
x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
grid = torch.cat((y, x), 2).view(-1, 2)

# Arrow origins: the grid points whose *squared* Euclidean distance to the
# nearest data point is below 0.1, i.e. points within about 0.32 of the data.
# `dist` has one entry per grid point (min over all data points).
dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
# Boolean-mask indexing keeps the selected rows in their original order.
origins = grid[dist < 0.1]

# Displacement the trained denoiser applies to each origin. It is only used
# for plotting, so evaluate without building an autograd graph.
with torch.no_grad():
    field = model(origins) - origins
71 ######################################################################
73 import matplotlib.pyplot as plt
# NOTE(review): `fig` is created, and the quiver/scatter calls are closed, on
# lines not shown in this excerpt — confirm.
76 ax = fig.add_subplot(1, 1, 1)
# Axis limits are slightly wider than the [-1.5, 1.5] arrow grid, presumably to
# leave a margin around the outermost arrows.
79 ax.set_xlim(-1.6, 1.6)
80 ax.set_ylim(-1.6, 1.6)
# One arrow per origin, pointing along the model's displacement `field`.
# Column 0 is plotted horizontally and column 1 vertically, for both the
# arrows and the data scatter.
# NOTE(review): units='xy' with scale=1 is presumably meant to draw arrows at
# their true length in data units; matplotlib's quiver also depends on
# `scale_units` / `angles` for that — verify.
83 plot_field = ax.quiver(
84 origins[:, 0].numpy(), origins[:, 1].numpy(),
85 field[:, 0].numpy(), field[:, 1].numpy(),
86 units = 'xy', scale = 1,
87 width = 3e-3, headwidth = 25, headlength = 25
# Training data drawn as small blue dots on top of the field.
90 plot_data = ax.scatter(
91 data[:, 0].numpy(), data[:, 1].numpy(),
92 s = 1, color = 'tab:blue'
95 fig.savefig('denoising_field.pdf', bbox_inches='tight')
97 ######################################################################