X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=27842d9e0d785f4b1dac433702f8a7782db69444;hb=989834728791abe50c14aad4294e1ef10d9bbf35;hp=075eb8208455045bc5df1d869e26de639419eee1;hpb=142b09825bec53a432795cb34c2cc325b0e994c2;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 075eb82..27842d9 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -9,40 +9,68 @@ import math, argparse import matplotlib.pyplot as plt -import torch +import torch, torchvision from torch import nn device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'device {device}') + ###################################################################### def sample_gaussian_mixture(nb): p, std = 0.3, 0.2 - result = torch.empty(nb, 1, device = device).normal_(0, std) - result = result + torch.sign(torch.rand(result.size(), device = device) - p) / 2 + result = torch.randn(nb, 1) * std + result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result -def sample_arc(nb): - theta = torch.rand(nb, device = device) * math.pi - rho = torch.rand(nb, device = device) * 0.1 + 0.7 - result = torch.empty(nb, 2, device = device) - result[:, 0] = theta.cos() * rho - result[:, 1] = theta.sin() * rho +def sample_ramp(nb): + result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1)) + return result + +def sample_two_discs(nb): + a = torch.rand(nb) * math.pi * 2 + b = torch.rand(nb).sqrt() + q = (torch.rand(nb) <= 0.5).long() + b = b * (0.3 + 0.2 * q) + result = torch.empty(nb, 2) + result[:, 0] = a.cos() * b - 0.5 + q + result[:, 1] = a.sin() * b - 0.5 + q + return result + +def sample_disc_grid(nb): + a = torch.rand(nb) * math.pi * 2 + b = torch.rand(nb).sqrt() + N = 4 + q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2) + r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2) + b = b * 0.1 + result = torch.empty(nb, 2) + result[:, 0] = a.cos() * b + q + result[:, 1] = a.sin() * b + r return result def sample_spiral(nb): - u = torch.rand(nb, device = device) - rho = u * 0.65 + 0.25 + torch.rand(nb, device = device) * 0.15 + u = torch.rand(nb) + rho = u * 0.65 + 0.25 + torch.rand(nb) * 0.15 theta = u * math.pi * 3 - result = torch.empty(nb, 2, device = device) + result = torch.empty(nb, 2) result[:, 0] = theta.cos() * rho result[:, 1] = theta.sin() * rho return result +def sample_mnist(nb): + train_set = torchvision.datasets.MNIST(root = './data/', train = True, download = True) + result = train_set.data[:nb].to(device).view(-1, 1, 28, 28).float() + return result + samplers = { 'gaussian_mixture': sample_gaussian_mixture, - 'arc': sample_arc, + 'ramp': sample_ramp, + 'two_discs': sample_two_discs, + 'disc_grid': sample_disc_grid, 'spiral': sample_spiral, + 'mnist': sample_mnist, } ###################################################################### @@ -77,7 +105,7 @@ parser.add_argument('--learning_rate', parser.add_argument('--ema_decay', type = float, default = 0.9999, - help = 'EMA decay, < 0 is no EMA') + help = 'EMA decay, <= 0 is no EMA') data_list = ', '.join( [ str(k) for k in samplers ]) @@ -101,46 +129,119 @@ class EMA: def __init__(self, model, decay): self.model = model self.decay = decay - if self.decay < 0: return - self.ema = { } + self.mem = { } with torch.no_grad(): for p in model.parameters(): - self.ema[p] = p.clone() + self.mem[p] = p.clone() def step(self): - if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): - self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p) - def copy(self): - if self.decay < 0: return + def copy_to_model(self): with torch.no_grad(): for p in self.model.parameters(): - p.copy_(self.ema[p]) + p.copy_(self.mem[p]) ###################################################################### -# Train + +# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an +# additional dimension / channel + +class TimeAppender(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, u): + x, t = u + if not torch.is_tensor(t): + t = x.new_full((x.size(0),), t) + t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1]) + return torch.cat((x, t), 1) + +class ConvNet(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + ks, nc = 5, 64 + + self.core = nn.Sequential( + TimeAppender(), + nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, nc, ks, padding = ks//2), + nn.ReLU(), + nn.Conv2d(nc, out_channels, ks, padding = ks//2), + ) + + def forward(self, u): + return self.core(u) + +###################################################################### +# Data try: - train_input = samplers[args.data](args.nb_samples) + train_input = samplers[args.data](args.nb_samples).to(device) except KeyError: print(f'unknown data {args.data}') exit(1) +train_mean, train_std = train_input.mean(), train_input.std() + +###################################################################### +# Model + +if train_input.dim() == 2: + nh = 256 + + model = nn.Sequential( + TimeAppender(), + nn.Linear(train_input.size(1) + 1, nh), + nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), + nn.Linear(nh, train_input.size(1)), + ) + +elif train_input.dim() == 4: + + model = ConvNet(train_input.size(1), train_input.size(1)) + +model.to(device) + +print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}') + ###################################################################### +# Generate + +def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std): + + with torch.no_grad(): + + x = torch.randn(size, device = device) + + for t in range(T-1, -1, -1): + output = model((x, t / (T - 1) - 0.5)) + z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) + x = 1/torch.sqrt(alpha[t]) \ + * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \ + + sigma[t] * z -nh = 64 + x = x * train_std + train_mean -model = nn.Sequential( - nn.Linear(train_input.size(1) + 1, nh), - nn.ReLU(), - nn.Linear(nh, nh), - nn.ReLU(), - nn.Linear(nh, nh), - nn.ReLU(), - nn.Linear(nh, train_input.size(1)), -).to(device) + return x + +###################################################################### +# Train T = 1000 beta = torch.linspace(1e-4, 0.02, T, device = device) @@ -148,7 +249,7 @@ alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -ema = EMA(model, decay = args.ema_decay) +ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None for k in range(args.nb_epochs): @@ -156,79 +257,95 @@ for k in range(args.nb_epochs): optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) for x0 in train_input.split(args.batch_size): - t = torch.randint(T, (x0.size(0), 1), device = device) - eps = torch.randn(x0.size(), device = device) - input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps - input = torch.cat((input, 2 * t / T - 1), 1) - output = model(input) + x0 = (x0 - train_mean) / train_std + t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device) + eps = torch.randn_like(x0) + xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps + output = model((xt, t / (T - 1) - 0.5)) loss = (eps - output).pow(2).mean() + acc_loss += loss.item() * x0.size(0) + optimizer.zero_grad() loss.backward() optimizer.step() - acc_loss += loss.item() * x0.size(0) - - ema.step() + if ema is not None: ema.step() - if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}') + print(f'{k} {acc_loss / train_input.size(0)}') -ema.copy() +if ema is not None: ema.copy_to_model() ###################################################################### -# Generate +# Plot -x = torch.randn(10000, train_input.size(1), device = device) +model.eval() -for t in range(T-1, -1, -1): - z = torch.zeros(x.size(), device = device) if t == 0 else torch.randn(x.size(), device = device) - input = torch.cat((x, torch.ones(x.size(0), 1, device = device) * 2 * t / T - 1), 1) - x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ - + sigma[t] * z +if train_input.dim() == 2: -###################################################################### -# Plot + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + # Nx1 -> histogram + if train_input.size(1) == 1: + + x = generate((10000, 1), alpha, alpha_bar, sigma, + model, train_mean, train_std) + + ax.set_xlim(-1.25, 1.25) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + + d = train_input.flatten().detach().to('cpu').numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') -fig = plt.figure() -ax = fig.add_subplot(1, 1, 1) + d = x.flatten().detach().to('cpu').numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') -if train_input.size(1) == 1: + ax.legend(frameon = False, loc = 2) - ax.set_xlim(-1.25, 1.25) + # Nx2 -> scatter plot + elif train_input.size(1) == 2: - d = train_input.flatten().detach().to('cpu').numpy() - ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'stepfilled', color = 'lightblue', label = 'Train') + x = generate((1000, 2), alpha, alpha_bar, sigma, + model, train_mean, train_std) - d = x.flatten().detach().to('cpu').numpy() - ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'step', color = 'red', label = 'Synthesis') + ax.set_xlim(-1.5, 1.5) + ax.set_ylim(-1.5, 1.5) + ax.set(aspect = 1) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) - ax.legend(frameon = False, loc = 2) + d = x.detach().to('cpu').numpy() + ax.scatter(d[:, 0], d[:, 1], + s = 2.0, color = 'red', label = 'Synthesis') -elif train_input.size(1) == 2: + d = train_input[:x.size(0)].detach().to('cpu').numpy() + ax.scatter(d[:, 0], d[:, 1], + s = 2.0, color = 'gray', label = 'Train') - ax.set_xlim(-1.25, 1.25) - ax.set_ylim(-1.25, 1.25) - ax.set(aspect = 1) + ax.legend(frameon = False, loc = 2) - d = train_input[:200].detach().to('cpu').numpy() - ax.scatter(d[:, 0], d[:, 1], - color = 'lightblue', label = 'Train') + filename = f'diffusion_{args.data}.pdf' + print(f'saving {filename}') + fig.savefig(filename, bbox_inches='tight') - d = x[:200].detach().to('cpu').numpy() - ax.scatter(d[:, 0], d[:, 1], - color = 'red', label = 'Synthesis') + if hasattr(plt.get_current_fig_manager(), 'window'): + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + plt.show() - ax.legend(frameon = False, loc = 2) +# NxCxHxW -> image +elif train_input.dim() == 4: -filename = f'diffusion_{args.data}.pdf' -print(f'saving {filename}') -fig.savefig(filename, bbox_inches='tight') + x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, + model, train_mean, train_std) + x = 1 - x.clamp(min = 0, max = 255) / 255 -if hasattr(plt.get_current_fig_manager(), 'window'): - plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) - plt.show() + filename = f'diffusion_{args.data}.png' + print(f'saving {filename}') + torchvision.utils.save_image(x, filename, nrow = 16, pad_value = 0.8) ######################################################################