-data = data_spiral(1000)
-# data = data_zigzag(1000)
-
-data = data - data.mean(0)
-
-batch_size, nb_epochs = 100, 1000
-optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
-criterion = nn.MSELoss()
-
-for e in range(nb_epochs):
- acc_loss = 0
- for input in data.split(batch_size):
- noise = input.new(input.size()).normal_(0, 0.1)
- output = model(input + noise)
- loss = criterion(output, input)
- acc_loss += loss.item()
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- if (e+1)%10 == 0: print(e+1, acc_loss)
+def data_penta(nb):
+ a = (torch.randint(5, (nb,)).float() / 5 * 2 * math.pi).view(-1, 1)
+ x = a.cos()
+ y = a.sin()
+ data = torch.cat((y, x), 1)
+ data = data + data.new(data.size()).normal_(0, 0.05)
+ return data, 'penta'