--- /dev/null
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+nb = 100
+delta = 0.35
+x = torch.empty(nb).uniform_(0.0, delta)
+x += x.new_full(x.size(), 0.5).bernoulli() * (1 - delta)
+
+a = x * math.pi * 2 * 4
+b = x * math.pi * 2 * 3
+y = a.sin() + b
+
+x = x.view(-1, 1)
+y = y.view(-1, 1)
+
+######################################################################
+
+nh = 100
+
+model = nn.Sequential(nn.Linear(1, nh), nn.ReLU(),
+ nn.Linear(nh, nh), nn.ReLU(),
+ nn.Linear(nh, 1))
+
+criterion = nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
+
+for k in range(10000):
+ loss = criterion(model(x), y)
+ if (k+1)%100 == 0: print(k+1, loss.item())
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+######################################################################
+
+import matplotlib.pyplot as plt
+
+fig, ax = plt.subplots()
+ax.scatter(x.numpy(), y.numpy())
+
+u = torch.linspace(0, 1, 100).view(-1, 1)
+ax.plot(u.numpy(), model(u).detach().numpy(), color = 'red')
+plt.show()
+
+######################################################################
type = int, default = 100,
help = 'Batch size')
+parser.add_argument('--independent', action = 'store_true',
+ help = 'Should the pair components be independent')
+
######################################################################
def entropy(target):
c = torch.cat(uc, 0)
perm = torch.randperm(a.size(0))
a = a[perm].contiguous()
+
+ if args.independent:
+ perm = torch.randperm(a.size(0))
b = b[perm].contiguous()
return a, b, c
b = a.new(a.size(0), 2)
b[:, 0].uniform_(0.0, 10.0)
b[:, 1].uniform_(0.0, 0.5)
- b[:, 1] += b[:, 0] + target.float()
+
+ if args.independent:
+ b[:, 1] += b[:, 0] + used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+ else:
+ b[:, 1] += b[:, 0] + target.float()
return a, b, c
noise_level = 2e-2
ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
- # hb = torch.randint(args.nb_classes, (nb, ), device = device)
- hb = ha
+ if args.independent:
+ hb = torch.randint(args.nb_classes, (nb, ), device = device)
+ else:
+ hb = ha
pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
######################################################################
a, b, c = create_pairs()
for k in range(10):
- file = open(f'/tmp/train_{k:02d}.dat', 'w')
+ file = open(f'train_{k:02d}.dat', 'w')
for i in range(a.size(1)):
file.write(f'{a[k, i]:f} {b[k,i]:f}\n')
file.close()