#!/usr/bin/env python
-import math
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
-import torch, torchvision
+import math
+import matplotlib.pyplot as plt
+import torch
from torch import nn
-from torch.nn import functional as F
-model = nn.Sequential(
- nn.Linear(2, 100),
- nn.ReLU(),
- nn.Linear(100, 2)
-)
+######################################################################
-############################################################
+def data_rectangle(nb):
+ x = torch.rand(nb, 1) - 0.5
+ y = torch.rand(nb, 1) * 2 - 1
+ data = torch.cat((y, x), 1)
+ alpha = math.pi / 8
+ data = data @ torch.tensor(
+ [
+ [ math.cos(alpha), math.sin(alpha)],
+ [-math.sin(alpha), math.cos(alpha)]
+ ]
+ )
+ return data, 'rectangle'
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
y = a * 2.5 - 1.25
data = torch.cat((y, x), 1)
data = data @ torch.tensor([[1., -1.], [1., 1.]])
- return data
+ return data, 'zigzag'
def data_spiral(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5)
y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5)
data = torch.cat((y, x), 1)
- return data
-
-######################################################################
+ return data, 'spiral'
-data = data_spiral(1000)
-# data = data_zigzag(1000)
-
-data = data - data.mean(0)
-
-batch_size, nb_epochs = 100, 1000
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
-criterion = nn.MSELoss()
-
-for e in range(nb_epochs):
- acc_loss = 0
- for input in data.split(batch_size):
- noise = input.new(input.size()).normal_(0, 0.1)
- output = model(input + noise)
- loss = criterion(output, input)
- acc_loss += loss.item()
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- if (e+1)%10 == 0: print(e+1, acc_loss)
+def data_penta(nb):
+ a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
+ x = a.cos()
+ y = a.sin()
+ data = torch.cat((y, x), 1)
+ data = data + data.new(data.size()).normal_(0, 0.05)
+ return data, 'penta'
######################################################################
-a = torch.linspace(-1.5, 1.5, 30)
-x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
-y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
-grid = torch.cat((y, x), 2).view(-1, 2)
+def train_model(data):
+ model = nn.Sequential(
+ nn.Linear(2, 100),
+ nn.ReLU(),
+ nn.Linear(100, 2)
+ )
+
+ batch_size, nb_epochs = 100, 1000
+ optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+ criterion = nn.MSELoss()
+
+ for e in range(nb_epochs):
+ acc_loss = 0
+ for input in data.split(batch_size):
+ noise = input.new(input.size()).normal_(0, 0.1)
+ output = model(input + noise)
+ loss = criterion(output, input)
+ acc_loss += loss.item()
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if (e+1)%100 == 0: print(e+1, acc_loss)
+
+ return model
-# Take the origins of the arrows on the part of grid closer than 0.1
-# from the data points
-dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
-origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
+######################################################################
-field = model(origins).detach() - origins
+def save_image(data_name, model, data):
+ a = torch.linspace(-1.5, 1.5, 30)
+ x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
+ y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
+ grid = torch.cat((y, x), 2).view(-1, 2)
-######################################################################
+ # Take the origins of the arrows on the part of the grid closer than
+ # sqrt(0.1) to the data points
+ dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
+ origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)]
-import matplotlib.pyplot as plt
+ field = model(origins).detach() - origins
-fig = plt.figure()
-ax = fig.add_subplot(1, 1, 1)
+ fig = plt.figure()
+ ax = fig.add_subplot(1, 1, 1)
-ax.axis('off')
-ax.set_xlim(-1.6, 1.6)
-ax.set_ylim(-1.6, 1.6)
-ax.set_aspect(1)
+ ax.axis('off')
+ ax.set_xlim(-1.6, 1.6)
+ ax.set_ylim(-1.6, 1.6)
+ ax.set_aspect(1)
-plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(),
- field[:, 0].numpy(), field[:, 1].numpy(),
- units = 'xy', scale = 1,
- width = 3e-3, headwidth = 25, headlength = 25)
+ plot_field = ax.quiver(
+ origins[:, 0].numpy(), origins[:, 1].numpy(),
+ field[:, 0].numpy(), field[:, 1].numpy(),
+ units = 'xy', scale = 1,
+ width = 3e-3, headwidth = 25, headlength = 25
+ )
-plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue')
+ plot_data = ax.scatter(
+ data[:, 0].numpy(), data[:, 1].numpy(),
+ s = 1, color = 'tab:blue'
+ )
-fig.savefig('denoising_field.pdf', bbox_inches='tight')
+ filename = f'denoising_field_{data_name}.pdf'
+ print(f'Saving {filename}')
+ fig.savefig(filename, bbox_inches='tight')
######################################################################
+
+for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]:
+ data, data_name = data_source(1000)
+ data = data - data.mean(0)
+ model = train_model(data)
+ save_image(data_name, model, data)