def sample_gaussian_mixture(nb):
p, std = 0.3, 0.2
- result = torch.empty(nb, 1).normal_(0, std)
+ result = torch.randn(nb, 1) * std
result = result + torch.sign(torch.rand(result.size()) - p) / 2
return result
+def sample_ramp(nb):
+ result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
+ return result
+
def sample_two_discs(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
def sample_disc_grid(nb):
a = torch.rand(nb) * math.pi * 2
b = torch.rand(nb).sqrt()
- q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
- r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
+ N = 4
+ q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
+ r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
b = b * 0.1
result = torch.empty(nb, 2)
result[:, 0] = a.cos() * b + q
samplers = {
'gaussian_mixture': sample_gaussian_mixture,
+ 'ramp': sample_ramp,
'two_discs': sample_two_discs,
'disc_grid': sample_disc_grid,
'spiral': sample_spiral,
# Model
if train_input.dim() == 2:
- nh = 64
+ nh = 256
model = nn.Sequential(
nn.Linear(train_input.size(1) + 1, nh),
model.to(device)
+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
######################################################################
# Train
ema.step()
- if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
+ print(f'{k} {acc_loss / train_input.size(0)}')
ema.copy()
x = generate((1000, 2), model)
- ax.set_xlim(-1.25, 1.25)
- ax.set_ylim(-1.25, 1.25)
+ ax.set_xlim(-1.5, 1.5)
+ ax.set_ylim(-1.5, 1.5)
ax.set(aspect = 1)
-
- d = train_input[:x.size(0)].detach().to('cpu').numpy()
- ax.scatter(d[:, 0], d[:, 1],
- color = 'lightblue', label = 'Train')
+ ax.spines.right.set_visible(False)
+ ax.spines.top.set_visible(False)
d = x.detach().to('cpu').numpy()
ax.scatter(d[:, 0], d[:, 1],
facecolors = 'none', color = 'red', label = 'Synthesis')
+ d = train_input[:x.size(0)].detach().to('cpu').numpy()
+ ax.scatter(d[:, 0], d[:, 1],
+ s = 1.0, color = 'blue', label = 'Train')
+
ax.legend(frameon = False, loc = 2)
filename = f'diffusion_{args.data}.pdf'