Added the 5 cluster data-set.
[pytorch.git] / denoising-ae-field.py
index 67d2415..effee19 100755 (executable)
@@ -13,7 +13,7 @@ model = nn.Sequential(
     nn.Linear(100, 2)
 )
 
-############################################################
+######################################################################
 
 def data_zigzag(nb):
     a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
@@ -31,10 +31,19 @@ def data_spiral(nb):
     data = torch.cat((y, x), 1)
     return data
 
+def data_penta(nb):
+    a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
+    x = a.cos()
+    y = a.sin()
+    data = torch.cat((y, x), 1)
+    data = data + data.new(data.size()).normal_(0, 0.05)
+    return data
+
 ######################################################################
 
-# data = data_spiral(1000)
-data = data_zigzag(1000)
+data = data_spiral(1000)
+# data = data_zigzag(1000)
+# data = data_penta(1000)
 
 data = data - data.mean(0)
 
@@ -80,12 +89,17 @@ ax.set_xlim(-1.6, 1.6)
 ax.set_ylim(-1.6, 1.6)
 ax.set_aspect(1)
 
-plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(),
-                       field[:, 0].numpy(), field[:, 1].numpy(),
-                       units = 'xy', scale = 1,
-                       width = 3e-3, headwidth = 25, headlength = 25)
+plot_field = ax.quiver(
+    origins[:, 0].numpy(), origins[:, 1].numpy(),
+    field[:, 0].numpy(), field[:, 1].numpy(),
+    units = 'xy', scale = 1,
+    width = 3e-3, headwidth = 25, headlength = 25
+)
 
-plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue')
+plot_data = ax.scatter(
+    data[:, 0].numpy(), data[:, 1].numpy(),
+    s = 1, color = 'tab:blue'
+)
 
 fig.savefig('denoising_field.pdf', bbox_inches='tight')