Tried to make the source clearer, added the TimeAppender Module.
[pytorch.git] / minidiffusion.py
index 879b796..27842d9 100755 (executable)
@@ -105,7 +105,7 @@ parser.add_argument('--learning_rate',
 
 parser.add_argument('--ema_decay',
                     type = float, default = 0.9999,
-                    help = 'EMA decay, < 0 is no EMA')
+                    help = 'EMA decay, <= 0 is no EMA')
 
 data_list = ', '.join( [ str(k) for k in samplers ])
 
@@ -129,26 +129,37 @@ class EMA:
     def __init__(self, model, decay):
         self.model = model
         self.decay = decay
-        if self.decay < 0: return
-        self.ema = { }
+        self.mem = { }
         with torch.no_grad():
             for p in model.parameters():
-                self.ema[p] = p.clone()
+                self.mem[p] = p.clone()
 
     def step(self):
-        if self.decay < 0: return
         with torch.no_grad():
             for p in self.model.parameters():
-                self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
+                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
 
-    def copy(self):
-        if self.decay < 0: return
+    def copy_to_model(self):
         with torch.no_grad():
             for p in self.model.parameters():
-                p.copy_(self.ema[p])
+                p.copy_(self.mem[p])
 
 ######################################################################
 
+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, u):
+        x, t = u
+        if not torch.is_tensor(t):
+            t = x.new_full((x.size(0),), t)
+        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+        return torch.cat((x, t), 1)
+
 class ConvNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
@@ -156,7 +167,8 @@ class ConvNet(nn.Module):
         ks, nc = 5, 64
 
         self.core = nn.Sequential(
-            nn.Conv2d(in_channels, nc, ks, padding = ks//2),
+            TimeAppender(),
+            nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
             nn.ReLU(),
             nn.Conv2d(nc, nc, ks, padding = ks//2),
             nn.ReLU(),
@@ -169,8 +181,8 @@ class ConvNet(nn.Module):
             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
         )
 
-    def forward(self, x):
-        return self.core(x)
+    def forward(self, u):
+        return self.core(u)
 
 ######################################################################
 # Data
@@ -190,6 +202,7 @@ if train_input.dim() == 2:
     nh = 256
 
     model = nn.Sequential(
+        TimeAppender(),
         nn.Linear(train_input.size(1) + 1, nh),
         nn.ReLU(),
         nn.Linear(nh, nh),
@@ -201,7 +214,7 @@ if train_input.dim() == 2:
 
 elif train_input.dim() == 4:
 
-    model = ConvNet(train_input.size(1) + 1, train_input.size(1))
+    model = ConvNet(train_input.size(1), train_input.size(1))
 
 model.to(device)
 
@@ -210,15 +223,17 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 ######################################################################
 # Generate
 
-def generate(size, alpha, alpha_bar, sigma, model):
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
     with torch.no_grad():
+
         x = torch.randn(size, device = device)
 
         for t in range(T-1, -1, -1):
+            output = model((x, t / (T - 1) - 0.5))
             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
-            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
             x = 1/torch.sqrt(alpha[t]) \
-                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
                 + sigma[t] * z
 
         x = x * train_std + train_mean
@@ -234,7 +249,7 @@ alpha = 1 - beta
 alpha_bar = alpha.log().cumsum(0).exp()
 sigma = beta.sqrt()
 
-ema = EMA(model, decay = args.ema_decay)
+ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
 
 for k in range(args.nb_epochs):
 
@@ -245,20 +260,20 @@ for k in range(args.nb_epochs):
         x0 = (x0 - train_mean) / train_std
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
-        input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
-        loss = (eps - model(input)).pow(2).mean()
+        xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
+        output = model((xt, t / (T - 1) - 0.5))
+        loss = (eps - output).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-        ema.step()
+        if ema is not None: ema.step()
 
     print(f'{k} {acc_loss / train_input.size(0)}')
 
-ema.copy()
+if ema is not None: ema.copy_to_model()
 
 ######################################################################
 # Plot
@@ -270,9 +285,11 @@ if train_input.dim() == 2:
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
 
+    # Nx1 -> histogram
     if train_input.size(1) == 1:
 
-        x = generate((10000, 1), alpha, alpha_bar, sigma, model)
+        x = generate((10000, 1), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.25, 1.25)
         ax.spines.right.set_visible(False)
@@ -290,9 +307,11 @@ if train_input.dim() == 2:
 
         ax.legend(frameon = False, loc = 2)
 
+    # Nx2 -> scatter plot
     elif train_input.size(1) == 2:
 
-        x = generate((1000, 2), alpha, alpha_bar, sigma, model)
+        x = generate((1000, 2), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.5, 1.5)
         ax.set_ylim(-1.5, 1.5)
@@ -318,10 +337,15 @@ if train_input.dim() == 2:
         plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
         plt.show()
 
+# NxCxHxW -> image
 elif train_input.dim() == 4:
 
-    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model)
+    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
     x = 1 - x.clamp(min = 0, max = 255) / 255
-    torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
+
+    filename = f'diffusion_{args.data}.png'
+    print(f'saving {filename}')
+    torchvision.utils.save_image(x, filename, nrow = 16, pad_value = 0.8)
 
 ######################################################################