Update.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 14 Aug 2022 07:50:41 +0000 (09:50 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 14 Aug 2022 07:50:41 +0000 (09:50 +0200)
minidiffusion.py

index 2c54d19..6fd8564 100755 (executable)
@@ -18,10 +18,14 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 def sample_gaussian_mixture(nb):
     p, std = 0.3, 0.2
-    result = torch.empty(nb, 1).normal_(0, std)
+    result = torch.randn(nb, 1) * std
     result = result + torch.sign(torch.rand(result.size()) - p) / 2
     return result
 
+def sample_ramp(nb):
+    result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1))
+    return result
+
 def sample_two_discs(nb):
     a = torch.rand(nb) * math.pi * 2
     b = torch.rand(nb).sqrt()
@@ -35,8 +39,9 @@ def sample_two_discs(nb):
 def sample_disc_grid(nb):
     a = torch.rand(nb) * math.pi * 2
     b = torch.rand(nb).sqrt()
-    q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
-    r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5
+    N = 4
+    q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
+    r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2)
     b = b * 0.1
     result = torch.empty(nb, 2)
     result[:, 0] = a.cos() * b + q
@@ -59,6 +64,7 @@ def sample_mnist(nb):
 
 samplers = {
     'gaussian_mixture': sample_gaussian_mixture,
+    'ramp': sample_ramp,
     'two_discs': sample_two_discs,
     'disc_grid': sample_disc_grid,
     'spiral': sample_spiral,
@@ -179,7 +185,7 @@ train_mean, train_std = train_input.mean(), train_input.std()
 # Model
 
 if train_input.dim() == 2:
-    nh = 64
+    nh = 256
 
     model = nn.Sequential(
         nn.Linear(train_input.size(1) + 1, nh),
@@ -197,6 +203,8 @@ elif train_input.dim() == 4:
 
 model.to(device)
 
+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
 ######################################################################
 # Train
 
@@ -228,7 +236,7 @@ for k in range(args.nb_epochs):
 
         ema.step()
 
-    if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}')
+    print(f'{k} {acc_loss / train_input.size(0)}')
 
 ema.copy()
 
@@ -281,18 +289,20 @@ if train_input.dim() == 2:
 
         x = generate((1000, 2), model)
 
-        ax.set_xlim(-1.25, 1.25)
-        ax.set_ylim(-1.25, 1.25)
+        ax.set_xlim(-1.5, 1.5)
+        ax.set_ylim(-1.5, 1.5)
         ax.set(aspect = 1)
-
-        d = train_input[:x.size(0)].detach().to('cpu').numpy()
-        ax.scatter(d[:, 0], d[:, 1],
-                   color = 'lightblue', label = 'Train')
+        ax.spines.right.set_visible(False)
+        ax.spines.top.set_visible(False)
 
         d = x.detach().to('cpu').numpy()
         ax.scatter(d[:, 0], d[:, 1],
                    facecolors = 'none', color = 'red', label = 'Synthesis')
 
+        d = train_input[:x.size(0)].detach().to('cpu').numpy()
+        ax.scatter(d[:, 0], d[:, 1],
+                   s = 1.0, color = 'blue', label = 'Train')
+
         ax.legend(frameon = False, loc = 2)
 
     filename = f'diffusion_{args.data}.pdf'