Update.
[pytorch.git] / minidiffusion.py
index 879b796..841dd2a 100755 (executable)
@@ -105,7 +105,7 @@ parser.add_argument('--learning_rate',
 
 parser.add_argument('--ema_decay',
                     type = float, default = 0.9999,
-                    help = 'EMA decay, < 0 is no EMA')
+                    help = 'EMA decay, <= 0 is no EMA')
 
 data_list = ', '.join( [ str(k) for k in samplers ])
 
@@ -129,23 +129,20 @@ class EMA:
     def __init__(self, model, decay):
         self.model = model
         self.decay = decay
-        if self.decay < 0: return
-        self.ema = { }
+        self.mem = { }
         with torch.no_grad():
             for p in model.parameters():
-                self.ema[p] = p.clone()
+                self.mem[p] = p.clone()
 
     def step(self):
-        if self.decay < 0: return
         with torch.no_grad():
             for p in self.model.parameters():
-                self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
+                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
 
-    def copy(self):
-        if self.decay < 0: return
+    def copy_to_model(self):
         with torch.no_grad():
             for p in self.model.parameters():
-                p.copy_(self.ema[p])
+                p.copy_(self.mem[p])
 
 ######################################################################
 
@@ -210,8 +207,10 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 ######################################################################
 # Generate
 
-def generate(size, alpha, alpha_bar, sigma, model):
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
     with torch.no_grad():
+
         x = torch.randn(size, device = device)
 
         for t in range(T-1, -1, -1):
@@ -234,7 +233,7 @@ alpha = 1 - beta
 alpha_bar = alpha.log().cumsum(0).exp()
 sigma = beta.sqrt()
 
-ema = EMA(model, decay = args.ema_decay)
+ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
 
 for k in range(args.nb_epochs):
 
@@ -245,8 +244,8 @@ for k in range(args.nb_epochs):
         x0 = (x0 - train_mean) / train_std
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
-        input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
+        xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
+        input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
         loss = (eps - model(input)).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
@@ -254,11 +253,11 @@ for k in range(args.nb_epochs):
         loss.backward()
         optimizer.step()
 
-        ema.step()
+        if ema is not None: ema.step()
 
     print(f'{k} {acc_loss / train_input.size(0)}')
 
-ema.copy()
+if ema is not None: ema.copy_to_model()
 
 ######################################################################
 # Plot
@@ -272,7 +271,8 @@ if train_input.dim() == 2:
 
     if train_input.size(1) == 1:
 
-        x = generate((10000, 1), alpha, alpha_bar, sigma, model)
+        x = generate((10000, 1), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.25, 1.25)
         ax.spines.right.set_visible(False)
@@ -292,7 +292,8 @@ if train_input.dim() == 2:
 
     elif train_input.size(1) == 2:
 
-        x = generate((1000, 2), alpha, alpha_bar, sigma, model)
+        x = generate((1000, 2), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.5, 1.5)
         ax.set_ylim(-1.5, 1.5)
@@ -320,7 +321,8 @@ if train_input.dim() == 2:
 
 elif train_input.dim() == 4:
 
-    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model)
+    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
     x = 1 - x.clamp(min = 0, max = 255) / 255
     torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)