X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=denoising-ae-field.py;h=effee19dba0d43d0fd8c00e92b4a0a504646ee1a;hp=cc4af814c75e95384fbb62629a2cf1673a37e6ad;hb=fe7bae674b22a2ec4994a34be4a509b1a0d0ba72;hpb=7816075cd507d14ef09df13afbc4be525bedd08c diff --git a/denoising-ae-field.py b/denoising-ae-field.py index cc4af81..effee19 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -31,10 +31,19 @@ def data_spiral(nb): data = torch.cat((y, x), 1) return data +def data_penta(nb): + a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1) + x = a.cos() + y = a.sin() + data = torch.cat((y, x), 1) + data = data + data.new(data.size()).normal_(0, 0.05) + return data + ###################################################################### -# data = data_spiral(1000) -data = data_zigzag(1000) +data = data_spiral(1000) +# data = data_zigzag(1000) +# data = data_penta(1000) data = data - data.mean(0)