Update.
[pytorch.git] / minidiffusion.py
index e1f6abd..066cbbb 100755 (executable)
@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 
 import torch, torchvision
 from torch import nn
+from torch.nn import functional as F
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -65,12 +66,14 @@ def sample_mnist(nb):
     return result
 
 samplers = {
-    'gaussian_mixture': sample_gaussian_mixture,
-    'ramp': sample_ramp,
-    'two_discs': sample_two_discs,
-    'disc_grid': sample_disc_grid,
-    'spiral': sample_spiral,
-    'mnist': sample_mnist,
+    f.__name__.removeprefix('sample_') : f for f in [
+        sample_gaussian_mixture,
+        sample_ramp,
+        sample_two_discs,
+        sample_disc_grid,
+        sample_spiral,
+        sample_mnist,
+    ]
 }
 
 ######################################################################
@@ -113,6 +116,9 @@ parser.add_argument('--data',
                     type = str, default = 'gaussian_mixture',
                     help = f'Toy data-set to use: {data_list}')
 
+parser.add_argument('--no_window',
+                    action='store_true', default = False)
+
 args = parser.parse_args()
 
 if args.seed >= 0:
@@ -146,6 +152,20 @@ class EMA:
 
 ######################################################################
 
+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, u):
+        x, t = u
+        if not torch.is_tensor(t):
+            t = x.new_full((x.size(0),), t)
+        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+        return torch.cat((x, t), 1)
+
 class ConvNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
@@ -153,7 +173,8 @@ class ConvNet(nn.Module):
         ks, nc = 5, 64
 
         self.core = nn.Sequential(
-            nn.Conv2d(in_channels, nc, ks, padding = ks//2),
+            TimeAppender(),
+            nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
             nn.ReLU(),
             nn.Conv2d(nc, nc, ks, padding = ks//2),
             nn.ReLU(),
@@ -166,8 +187,8 @@ class ConvNet(nn.Module):
             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
         )
 
-    def forward(self, x):
-        return self.core(x)
+    def forward(self, u):
+        return self.core(u)
 
 ######################################################################
 # Data
@@ -187,6 +208,7 @@ if train_input.dim() == 2:
     nh = 256
 
     model = nn.Sequential(
+        TimeAppender(),
         nn.Linear(train_input.size(1) + 1, nh),
         nn.ReLU(),
         nn.Linear(nh, nh),
@@ -198,7 +220,7 @@ if train_input.dim() == 2:
 
 elif train_input.dim() == 4:
 
-    model = ConvNet(train_input.size(1) + 1, train_input.size(1))
+    model = ConvNet(train_input.size(1), train_input.size(1))
 
 model.to(device)
 
@@ -207,15 +229,17 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 ######################################################################
 # Generate
 
-def generate(size, alpha, alpha_bar, sigma, model):
+def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
     with torch.no_grad():
+
         x = torch.randn(size, device = device)
 
         for t in range(T-1, -1, -1):
+            output = model((x, t / (T - 1) - 0.5))
             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
-            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
             x = 1/torch.sqrt(alpha[t]) \
-                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
                 + sigma[t] * z
 
         x = x * train_std + train_mean
@@ -243,8 +267,8 @@ for k in range(args.nb_epochs):
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
-        loss = (eps - model(input)).pow(2).mean()
+        output = model((xt, t / (T - 1) - 0.5))
+        loss = (eps - output).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
         optimizer.zero_grad()
@@ -262,63 +286,103 @@ if ema is not None: ema.copy_to_model()
 
 model.eval()
 
-if train_input.dim() == 2:
+########################################
+# Nx1 -> histogram
+if train_input.dim() == 2 and train_input.size(1) == 1:
 
     fig = plt.figure()
+    fig.set_figheight(5)
+    fig.set_figwidth(8)
+
     ax = fig.add_subplot(1, 1, 1)
 
-    if train_input.size(1) == 1:
+    x = generate((10000, 1), T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
 
-        x = generate((10000, 1), alpha, alpha_bar, sigma, model)
+    ax.set_xlim(-1.25, 1.25)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
 
-        ax.set_xlim(-1.25, 1.25)
-        ax.spines.right.set_visible(False)
-        ax.spines.top.set_visible(False)
+    d = train_input.flatten().detach().to('cpu').numpy()
+    ax.hist(d, 25, (-1, 1),
+            density = True,
+            histtype = 'bar', edgecolor = 'white', color = 'lightblue', label = 'Train')
 
-        d = train_input.flatten().detach().to('cpu').numpy()
-        ax.hist(d, 25, (-1, 1),
-                density = True,
-                histtype = 'stepfilled', color = 'lightblue', label = 'Train')
+    d = x.flatten().detach().to('cpu').numpy()
+    ax.hist(d, 25, (-1, 1),
+            density = True,
+            histtype = 'step', color = 'red', label = 'Synthesis')
 
-        d = x.flatten().detach().to('cpu').numpy()
-        ax.hist(d, 25, (-1, 1),
-                density = True,
-                histtype = 'step', color = 'red', label = 'Synthesis')
+    ax.legend(frameon = False, loc = 2)
 
-        ax.legend(frameon = False, loc = 2)
+    filename = f'minidiffusion_{args.data}.pdf'
+    print(f'saving {filename}')
+    fig.savefig(filename, bbox_inches='tight')
 
-    elif train_input.size(1) == 2:
+    if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
+        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
+        plt.show()
 
-        x = generate((1000, 2), alpha, alpha_bar, sigma, model)
+########################################
+# Nx2 -> scatter plot
+elif train_input.dim() == 2 and train_input.size(1) == 2:
 
-        ax.set_xlim(-1.5, 1.5)
-        ax.set_ylim(-1.5, 1.5)
-        ax.set(aspect = 1)
-        ax.spines.right.set_visible(False)
-        ax.spines.top.set_visible(False)
+    fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(6)
+
+    ax = fig.add_subplot(1, 1, 1)
 
-        d = x.detach().to('cpu').numpy()
-        ax.scatter(d[:, 0], d[:, 1],
-                   s = 2.0, color = 'red', label = 'Synthesis')
+    x = generate((1000, 2), T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
 
-        d = train_input[:x.size(0)].detach().to('cpu').numpy()
-        ax.scatter(d[:, 0], d[:, 1],
-                   s = 2.0, color = 'gray', label = 'Train')
+    ax.set_xlim(-1.5, 1.5)
+    ax.set_ylim(-1.5, 1.5)
+    ax.set(aspect = 1)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
 
-        ax.legend(frameon = False, loc = 2)
+    d = train_input[:x.size(0)].detach().to('cpu').numpy()
+    ax.scatter(d[:, 0], d[:, 1],
+               s = 2.5, color = 'gray', label = 'Train')
 
-    filename = f'diffusion_{args.data}.pdf'
+    d = x.detach().to('cpu').numpy()
+    ax.scatter(d[:, 0], d[:, 1],
+               s = 2.0, color = 'red', label = 'Synthesis')
+
+    ax.legend(frameon = False, loc = 2)
+
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
-    if hasattr(plt.get_current_fig_manager(), 'window'):
+    if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
         plt.show()
 
+########################################
+# NxCxHxW -> image
 elif train_input.dim() == 4:
 
-    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model)
-    x = 1 - x.clamp(min = 0, max = 255) / 255
-    torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
+    x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
+
+    x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
+                                    nrow = 16, padding = 1, pad_value = 64)
+    x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
+
+    t = torchvision.utils.make_grid(train_input[:128],
+                                    nrow = 16, padding = 1, pad_value = 64)
+    t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
+
+    result = 1 - torch.cat((t, x), 2) / 255
+
+    filename = f'minidiffusion_{args.data}.png'
+    print(f'saving {filename}')
+    torchvision.utils.save_image(result, filename)
+
+else:
+
+    print(f'cannot plot result of size {train_input.size()}')
 
 ######################################################################