X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=2c54d196062ee1775385335d67c03ba29a34b3ca;hb=a3c7617d0b5770edf6030502e4eac477a7218820;hp=cbdb1425b090d7f13b33eaa6779e8a9177821d70;hpb=bc937c74ad8cbeede2c2ae97cda72eaa3e9bb4f3;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index cbdb142..2c54d19 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,85 +5,307 @@ # Written by Francois Fleuret +import math, argparse + import matplotlib.pyplot as plt -import torch + +import torch, torchvision from torch import nn +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + ###################################################################### -def sample_phi(nb): +def sample_gaussian_mixture(nb): p, std = 0.3, 0.2 - result = torch.empty(nb).normal_(0, std) + result = torch.empty(nb, 1).normal_(0, std) result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result +def sample_two_discs(nb): + a = torch.rand(nb) * math.pi * 2 + b = torch.rand(nb).sqrt() + q = (torch.rand(nb) <= 0.5).long() + b = b * (0.3 + 0.2 * q) + result = torch.empty(nb, 2) + result[:, 0] = a.cos() * b - 0.5 + q + result[:, 1] = a.sin() * b - 0.5 + q + return result + +def sample_disc_grid(nb): + a = torch.rand(nb) * math.pi * 2 + b = torch.rand(nb).sqrt() + q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5 + r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5 + b = b * 0.1 + result = torch.empty(nb, 2) + result[:, 0] = a.cos() * b + q + result[:, 1] = a.sin() * b + r + return result + +def sample_spiral(nb): + u = torch.rand(nb) + rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15 + theta = u * math.pi * 3 + result = torch.empty(nb, 2) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +def sample_mnist(nb): + train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True) + result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float() + return result + +samplers = { + 'gaussian_mixture': sample_gaussian_mixture, + 'two_discs': sample_two_discs, + 'disc_grid': sample_disc_grid, + 'spiral': sample_spiral, + 'mnist': sample_mnist, +} + ###################################################################### -model = nn.Sequential( - nn.Linear(2, 32), - nn.ReLU(), - nn.Linear(32, 32), - nn.ReLU(), - nn.Linear(32, 1), +parser = argparse.ArgumentParser( + description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +"Denoising Diffusion Probabilistic Models" (2020) +https://arxiv.org/abs/2006.11239''', + + formatter_class = argparse.ArgumentDefaultsHelpFormatter ) +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed, < 0 is no seeding') + +parser.add_argument('--nb_epochs', + type = int, default = 100, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 25, + help = 'Batch size') + +parser.add_argument('--nb_samples', + type = int, default = 25000, + help = 'Number of training examples') + +parser.add_argument('--learning_rate', + type = float, default = 1e-3, + help = 'Learning rate') + +parser.add_argument('--ema_decay', + type = float, default = 0.9999, + help = 'EMA decay, < 0 is no EMA') + +data_list = ', '.join( [ str(k) for k in samplers ]) + +parser.add_argument('--data', + type = str, default = 'gaussian_mixture', + help = f'Toy data-set to use: {data_list}') + +args = parser.parse_args() + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + ###################################################################### -# Train -nb_samples = 25000 -nb_epochs = 250 -batch_size = 100 +class EMA: + def __init__(self, model, decay): + self.model = model + self.decay = decay + if self.decay < 0: return + self.ema = { } + with torch.no_grad(): + for p in model.parameters(): + self.ema[p] = p.clone() + + def step(self): + if self.decay < 0: return + with torch.no_grad(): + for p in self.model.parameters(): + self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + + def copy(self): + if self.decay < 0: return + with torch.no_grad(): + for p in self.model.parameters(): + p.copy_(self.ema[p]) + +###################################################################### + +class ConvNet(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + ks, nc = 5, 64 + + self.core = nn.Sequential( + nn.Conv2d(in_channels, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, out_channels, ks, padding = ks//2), + ) + + def forward(self, x): + return self.core(x) + +###################################################################### +# Data + +try: + train_input = samplers[args.data](args.nb_samples).to(device) +except KeyError: + print(f'unknown data {args.data}') + exit(1) + +train_mean, train_std = train_input.mean(), train_input.std() + +###################################################################### +# Model + +if train_input.dim() == 2: + nh = 64 + + model = nn.Sequential( + nn.Linear(train_input.size(1) + 1, nh), + nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Linear(nh, train_input.size(1)), + ) + +elif train_input.dim() == 4: + + model = ConvNet(train_input.size(1) + 1, train_input.size(1)) -train_input = sample_phi(nb_samples)[:, None] +model.to(device) + +###################################################################### +# Train T = 1000 -beta = torch.linspace(1e-4, 0.02, T) +beta = torch.linspace(1e-4, 0.02, T, device = device) alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -for k in range(nb_epochs): +ema = EMA(model, decay = args.ema_decay) + +for k in range(args.nb_epochs): + acc_loss = 0 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) ) - - for x0 in train_input.split(batch_size): - t = torch.randint(T, (x0.size(0), 1)) - eps = torch.randn(x0.size()) - input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps - input = torch.cat((input, 2 * t / T - 1), 1) - output = model(input) - loss = (eps - output).pow(2).mean() + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) + + for x0 in train_input.split(args.batch_size): + x0 = (x0 - train_mean) / train_std + t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device) + eps = torch.randn_like(x0) + input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps + input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1) + loss = (eps - model(input)).pow(2).mean() + acc_loss += loss.item() * x0.size(0) + optimizer.zero_grad() loss.backward() optimizer.step() - acc_loss += loss.item() + ema.step() + + if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}') - if k%10 == 0: print(k, loss.item()) +ema.copy() + +###################################################################### +# Generate + +def generate(size, model): + with torch.no_grad(): + x = torch.randn(size, device = device) + + for t in range(T-1, -1, -1): + z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) + input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1) + x = 1/torch.sqrt(alpha[t]) \ + * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \ + + sigma[t] * z + + x = x * train_std + train_mean + + return x ###################################################################### # Plot -x = torch.randn(10000, 1) +model.eval() + +if train_input.dim() == 2: + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + if train_input.size(1) == 1: + + x = generate((10000, 1), model) + + ax.set_xlim(-1.25, 1.25) + + d = train_input.flatten().detach().to('cpu').numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') + + d = x.flatten().detach().to('cpu').numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') + + ax.legend(frameon = False, loc = 2) + + elif train_input.size(1) == 2: + + x = generate((1000, 2), model) + + ax.set_xlim(-1.25, 1.25) + ax.set_ylim(-1.25, 1.25) + ax.set(aspect = 1) -for t in range(T-1, -1, -1): - z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) - input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) - x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z + d = train_input[:x.size(0)].detach().to('cpu').numpy() + ax.scatter(d[:, 0], d[:, 1], + color = 'lightblue', label = 'Train') -fig = plt.figure() -ax = fig.add_subplot(1, 1, 1) -ax.set_xlim(-1.25, 1.25) + d = x.detach().to('cpu').numpy() + ax.scatter(d[:, 0], d[:, 1], + facecolors = 'none', color = 'red', label = 'Synthesis') -d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train') + ax.legend(frameon = False, loc = 2) -d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') + filename = f'diffusion_{args.data}.pdf' + print(f'saving {filename}') + fig.savefig(filename, bbox_inches='tight') -filename = 'diffusion.pdf' -fig.savefig(filename, bbox_inches='tight') + if hasattr(plt.get_current_fig_manager(), 'window'): + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + plt.show() -plt.show() +elif train_input.dim() == 4: + x = generate((128,) + train_input.size()[1:], model) + x = 1 - x.clamp(min = 0, max = 255) / 255 + torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8) ######################################################################