X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=denoising-ae-field.py;h=effee19dba0d43d0fd8c00e92b4a0a504646ee1a;hb=fe7bae674b22a2ec4994a34be4a509b1a0d0ba72;hp=175f344cd7318c76cd3ed70ff52aa6a515d2dbbc;hpb=0d0635ed4e6836ef2c48cd59fe3d25f7969e7bcf;p=pytorch.git diff --git a/denoising-ae-field.py b/denoising-ae-field.py index 175f344..effee19 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -13,7 +13,7 @@ model = nn.Sequential( nn.Linear(100, 2) ) -############################################################ +###################################################################### def data_zigzag(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) @@ -31,10 +31,19 @@ def data_spiral(nb): data = torch.cat((y, x), 1) return data +def data_penta(nb): + a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1) + x = a.cos() + y = a.sin() + data = torch.cat((y, x), 1) + data = data + data.new(data.size()).normal_(0, 0.05) + return data + ###################################################################### data = data_spiral(1000) # data = data_zigzag(1000) +# data = data_penta(1000) data = data - data.mean(0) @@ -80,12 +89,17 @@ ax.set_xlim(-1.6, 1.6) ax.set_ylim(-1.6, 1.6) ax.set_aspect(1) -plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(), - field[:, 0].numpy(), field[:, 1].numpy(), - units = 'xy', scale = 1, - width = 3e-3, headwidth = 25, headlength = 25) +plot_field = ax.quiver( + origins[:, 0].numpy(), origins[:, 1].numpy(), + field[:, 0].numpy(), field[:, 1].numpy(), + units = 'xy', scale = 1, + width = 3e-3, headwidth = 25, headlength = 25 +) -plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue') +plot_data = ax.scatter( + data[:, 0].numpy(), data[:, 1].numpy(), + s = 1, color = 'tab:blue' +) fig.savefig('denoising_field.pdf', bbox_inches='tight')