#!/usr/bin/env python
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
import math
import matplotlib.pyplot as plt
######################################################################
+def data_rectangle(nb):
+ x = torch.rand(nb, 1) - 0.5
+ y = torch.rand(nb, 1) * 2 - 1
+ data = torch.cat((y, x), 1)
+ alpha = math.pi / 8
+ data = data @ torch.tensor(
+ [
+ [ math.cos(alpha), math.sin(alpha)],
+ [-math.sin(alpha), math.cos(alpha)]
+ ]
+ )
+ return data, 'rectangle'
+
def data_zigzag(nb):
a = torch.empty(nb).uniform_(0, 1).view(-1, 1)
# zigzag
######################################################################
-def save_image(data, data_name, model):
+def save_image(data_name, model, data):
a = torch.linspace(-1.5, 1.5, 30)
x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1)
y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
######################################################################
-for data_source in [ data_zigzag, data_spiral, data_penta ]:
+for data_source in [ data_rectangle, data_zigzag, data_spiral, data_penta ]:
data, data_name = data_source(1000)
data = data - data.mean(0)
model = train_model(data)
- save_image(data, data_name, model)
+ save_image(data_name, model, data)