From d29423af71d76b52cbdf04367dbfe1908a492786 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 23 Dec 2019 17:39:26 +0100 Subject: [PATCH] Added the rectangle. --- denoising-ae-field.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/denoising-ae-field.py b/denoising-ae-field.py index 47e6ab4..f96c23a 100755 --- a/denoising-ae-field.py +++ b/denoising-ae-field.py @@ -13,6 +13,19 @@ from torch import nn ###################################################################### +def data_rectangle(nb): + x = torch.rand(nb, 1) - 0.5 + y = torch.rand(nb, 1) * 2 - 1 + data = torch.cat((y, x), 1) + alpha = math.pi / 8 + data = data @ torch.tensor( + [ + [ math.cos(alpha), math.sin(alpha)], + [-math.sin(alpha), math.cos(alpha)] + ] + ) + return data, 'rectangle' + def data_zigzag(nb): a = torch.empty(nb).uniform_(0, 1).view(-1, 1) # zigzag @@ -105,7 +118,7 @@ def save_image(data_name, model, data): ###################################################################### -for data_source in [ data_zigzag, data_spiral, data_penta ]: +for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]: data, data_name = data_source(1000) data = data - data.mean(0) model = train_model(data) -- 2.39.5