nn.Linear(100, 2)
)
-############################################################
+######################################################################
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
data = torch.cat((y, x), 1)
return data
+def data_penta(nb):
+ a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
+ x = a.cos()
+ y = a.sin()
+ data = torch.cat((y, x), 1)
+ data = data + data.new(data.size()).normal_(0, 0.05)
+ return data
+
######################################################################
data = data_spiral(1000)
# data = data_zigzag(1000)
+# data = data_penta(1000)
data = data - data.mean(0)
ax.set_ylim(-1.6, 1.6)
ax.set_aspect(1)
-plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(),
- field[:, 0].numpy(), field[:, 1].numpy(),
- units = 'xy', scale = 1,
- width = 3e-3, headwidth = 25, headlength = 25)
+plot_field = ax.quiver(
+ origins[:, 0].numpy(), origins[:, 1].numpy(),
+ field[:, 0].numpy(), field[:, 1].numpy(),
+ units = 'xy', scale = 1,
+ width = 3e-3, headwidth = 25, headlength = 25
+)
-plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue')
+plot_data = ax.scatter(
+ data[:, 0].numpy(), data[:, 1].numpy(),
+ s = 1, color = 'tab:blue'
+)
fig.savefig('denoising_field.pdf', bbox_inches='tight')