X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=27842d9e0d785f4b1dac433702f8a7782db69444;hb=989834728791abe50c14aad4294e1ef10d9bbf35;hp=879b7964825e2a89b20a06c9228311d448f3f380;hpb=560b7d51f52c7328e9d87ce717dacc4da7977de7;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 879b796..27842d9 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -105,7 +105,7 @@ parser.add_argument('--learning_rate', parser.add_argument('--ema_decay', type = float, default = 0.9999, - help = 'EMA decay, < 0 is no EMA') + help = 'EMA decay, <= 0 is no EMA') data_list = ', '.join( [ str(k) for k in samplers ]) @@ -129,26 +129,37 @@ class EMA: def __init__(self, model, decay): self.model = model self.decay = decay - if self.decay < 0: return - self.ema = { } + self.mem = { } with torch.no_grad(): for p in model.parameters(): - self.ema[p] = p.clone() + self.mem[p] = p.clone() def step(self): - if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): - self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p) - def copy(self): - if self.decay < 0: return + def copy_to_model(self): with torch.no_grad(): for p in self.model.parameters(): - p.copy_(self.ema[p]) + p.copy_(self.mem[p]) ###################################################################### +# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an +# additional dimension / channel + +class TimeAppender(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, u): + x, t = u + if not torch.is_tensor(t): + t = x.new_full((x.size(0),), t) + t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1]) + return torch.cat((x, t), 1) + class ConvNet(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() @@ -156,7 +167,8 @@ class ConvNet(nn.Module): ks, nc = 5, 64 self.core = nn.Sequential( - nn.Conv2d(in_channels, nc, ks, padding = ks//2), + TimeAppender(), + nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2), nn.ReLU(), nn.Conv2d(nc, nc, ks, padding = ks//2), nn.ReLU(), @@ -169,8 +181,8 @@ class ConvNet(nn.Module): nn.Conv2d(nc, out_channels, ks, padding = ks//2), ) - def forward(self, x): - return self.core(x) + def forward(self, u): + return self.core(u) ###################################################################### # Data @@ -190,6 +202,7 @@ if train_input.dim() == 2: nh = 256 model = nn.Sequential( + TimeAppender(), nn.Linear(train_input.size(1) + 1, nh), nn.ReLU(), nn.Linear(nh, nh), @@ -201,7 +214,7 @@ if train_input.dim() == 2: elif train_input.dim() == 4: - model = ConvNet(train_input.size(1) + 1, train_input.size(1)) + model = ConvNet(train_input.size(1), train_input.size(1)) model.to(device) @@ -210,15 +223,17 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}') ###################################################################### # Generate -def generate(size, alpha, alpha_bar, sigma, model): +def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std): + with torch.no_grad(): + x = torch.randn(size, device = device) for t in range(T-1, -1, -1): + output = model((x, t / (T - 1) - 0.5)) z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) - input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1) x = 1/torch.sqrt(alpha[t]) \ - * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \ + * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \ + sigma[t] * z x = x * train_std + train_mean @@ -234,7 +249,7 @@ alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -ema = EMA(model, decay = args.ema_decay) +ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None for k in range(args.nb_epochs): @@ -245,20 +260,20 @@ for k in range(args.nb_epochs): x0 = (x0 - train_mean) / train_std t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device) eps = torch.randn_like(x0) - input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps - input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1) - loss = (eps - model(input)).pow(2).mean() + xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps + output = model((xt, t / (T - 1) - 0.5)) + loss = (eps - output).pow(2).mean() acc_loss += loss.item() * x0.size(0) optimizer.zero_grad() loss.backward() optimizer.step() - ema.step() + if ema is not None: ema.step() print(f'{k} {acc_loss / train_input.size(0)}') -ema.copy() +if ema is not None: ema.copy_to_model() ###################################################################### # Plot @@ -270,9 +285,11 @@ if train_input.dim() == 2: fig = plt.figure() ax = fig.add_subplot(1, 1, 1) + # Nx1 -> histogram if train_input.size(1) == 1: - x = generate((10000, 1), alpha, alpha_bar, sigma, model) + x = generate((10000, 1), alpha, alpha_bar, sigma, + model, train_mean, train_std) ax.set_xlim(-1.25, 1.25) ax.spines.right.set_visible(False) @@ -290,9 +307,11 @@ if train_input.dim() == 2: ax.legend(frameon = False, loc = 2) + # Nx2 -> scatter plot elif train_input.size(1) == 2: - x = generate((1000, 2), alpha, alpha_bar, sigma, model) + x = generate((1000, 2), alpha, alpha_bar, sigma, + model, train_mean, train_std) ax.set_xlim(-1.5, 1.5) ax.set_ylim(-1.5, 1.5) @@ -318,10 +337,15 @@ if train_input.dim() == 2: plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() +# NxCxHxW -> image elif train_input.dim() == 4: - x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model) + x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, + model, train_mean, train_std) x = 1 - x.clamp(min = 0, max = 255) / 255 - torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8) + + filename = f'diffusion_{args.data}.png' + print(f'saving {filename}') + torchvision.utils.save_image(x, filename, nrow = 16, pad_value = 0.8) ######################################################################