######################################################################
+def data_rectangle(nb):
+ x = torch.rand(nb, 1) - 0.5
+ y = torch.rand(nb, 1) * 2 - 1
+ data = torch.cat((y, x), 1)
+ alpha = math.pi / 8
+ data = data @ torch.tensor(
+ [
+ [ math.cos(alpha), math.sin(alpha)],
+ [-math.sin(alpha), math.cos(alpha)]
+ ]
+ )
+ return data, 'rectangle'
+
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
######################################################################
-for data_source in [ data_zigzag, data_spiral, data_penta ]:
+for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]:
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)