X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=e7be8c1c8651a3cc4ee4a578586e2e4dd3c29bcf;hb=8d71d2f43cec159c7ca368c1dc4fa76f061d13b7;hp=e1f6abd102b039edc9b3408ca03e51b11ea855d9;hpb=2d19b3ce1dca606a27cd8d0e978ebe8710f7995c;p=pytorch.git

diff --git a/minidiffusion.py b/minidiffusion.py
index e1f6abd..e7be8c1 100755
--- a/minidiffusion.py
+++ b/minidiffusion.py
@@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
 import torch, torchvision
 
 from torch import nn
+from torch.nn import functional as F
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -113,6 +114,9 @@ parser.add_argument('--data',
                     type = str, default = 'gaussian_mixture',
                     help = f'Toy data-set to use: {data_list}')
 
+parser.add_argument('--no_window',
+                    action='store_true', default = False)
+
 args = parser.parse_args()
 
 if args.seed >= 0:
@@ -146,6 +150,20 @@ class EMA:
 
 ######################################################################
 
+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, u):
+        x, t = u
+        if not torch.is_tensor(t):
+            t = x.new_full((x.size(0),), t)
+        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+        return torch.cat((x, t), 1)
+
 class ConvNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
@@ -153,7 +171,8 @@ class ConvNet(nn.Module):
         ks, nc = 5, 64
 
         self.core = nn.Sequential(
-            nn.Conv2d(in_channels, nc, ks, padding = ks//2),
+            TimeAppender(),
+            nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
             nn.ReLU(),
             nn.Conv2d(nc, nc, ks, padding = ks//2),
             nn.ReLU(),
@@ -166,8 +185,8 @@ class ConvNet(nn.Module):
             nn.Conv2d(nc, out_channels, ks, padding = ks//2),
         )
 
-    def forward(self, x):
-        return self.core(x)
+    def forward(self, u):
+        return self.core(u)
 
 ######################################################################
 # Data
@@ -187,6 +206,7 @@ if train_input.dim() == 2:
     nh = 256
 
     model = nn.Sequential(
+        TimeAppender(),
         nn.Linear(train_input.size(1) + 1, nh),
         nn.ReLU(),
         nn.Linear(nh, nh),
@@ -198,7 +218,7 @@ if train_input.dim() == 2:
 
 elif train_input.dim() == 4:
 
-    model = ConvNet(train_input.size(1) + 1, train_input.size(1))
+    model = ConvNet(train_input.size(1), train_input.size(1))
 
 model.to(device)
 
@@ -207,15 +227,17 @@ print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 ######################################################################
 # Generate
 
-def generate(size, alpha, alpha_bar, sigma, model):
+def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
     with torch.no_grad():
+
         x = torch.randn(size, device = device)
 
         for t in range(T-1, -1, -1):
+            output = model((x, t / (T - 1) - 0.5))
             z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
-            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
             x = 1/torch.sqrt(alpha[t]) \
-                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * output) \
                 + sigma[t] * z
 
         x = x * train_std + train_mean
@@ -243,8 +265,8 @@ for k in range(args.nb_epochs):
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
         xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
-        loss = (eps - model(input)).pow(2).mean()
+        output = model((xt, t / (T - 1) - 0.5))
+        loss = (eps - output).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
         optimizer.zero_grad()
@@ -262,63 +284,103 @@ if ema is not None: ema.copy_to_model()
 
 model.eval()
 
-if train_input.dim() == 2:
+########################################
+# Nx1 -> histogram
+if train_input.dim() == 2 and train_input.size(1) == 1:
 
     fig = plt.figure()
+    fig.set_figheight(5)
+    fig.set_figwidth(8)
+
     ax = fig.add_subplot(1, 1, 1)
 
-    if train_input.size(1) == 1:
+    x = generate((10000, 1), T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
 
-        x = generate((10000, 1), alpha, alpha_bar, sigma, model)
+    ax.set_xlim(-1.25, 1.25)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
 
-        ax.set_xlim(-1.25, 1.25)
-        ax.spines.right.set_visible(False)
-        ax.spines.top.set_visible(False)
+    d = train_input.flatten().detach().to('cpu').numpy()
+    ax.hist(d, 25, (-1, 1),
+            density = True,
+            histtype = 'bar', edgecolor = 'white', color = 'lightblue', label = 'Train')
 
-        d = train_input.flatten().detach().to('cpu').numpy()
-        ax.hist(d, 25, (-1, 1),
-                density = True,
-                histtype = 'stepfilled', color = 'lightblue', label = 'Train')
+    d = x.flatten().detach().to('cpu').numpy()
+    ax.hist(d, 25, (-1, 1),
+            density = True,
+            histtype = 'step', color = 'red', label = 'Synthesis')
 
-        d = x.flatten().detach().to('cpu').numpy()
-        ax.hist(d, 25, (-1, 1),
-                density = True,
-                histtype = 'step', color = 'red', label = 'Synthesis')
+    ax.legend(frameon = False, loc = 2)
 
-        ax.legend(frameon = False, loc = 2)
+    filename = f'minidiffusion_{args.data}.pdf'
+    print(f'saving {filename}')
+    fig.savefig(filename, bbox_inches='tight')
 
-    elif train_input.size(1) == 2:
+    if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
+        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
+        plt.show()
 
-        x = generate((1000, 2), alpha, alpha_bar, sigma, model)
+########################################
+# Nx2 -> scatter plot
+elif train_input.dim() == 2 and train_input.size(1) == 2:
 
-        ax.set_xlim(-1.5, 1.5)
-        ax.set_ylim(-1.5, 1.5)
-        ax.set(aspect = 1)
-        ax.spines.right.set_visible(False)
-        ax.spines.top.set_visible(False)
+    fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(6)
+
+    ax = fig.add_subplot(1, 1, 1)
 
-        d = x.detach().to('cpu').numpy()
-        ax.scatter(d[:, 0], d[:, 1],
-                   s = 2.0, color = 'red', label = 'Synthesis')
+    x = generate((1000, 2), T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
 
-        d = train_input[:x.size(0)].detach().to('cpu').numpy()
-        ax.scatter(d[:, 0], d[:, 1],
-                   s = 2.0, color = 'gray', label = 'Train')
+    ax.set_xlim(-1.5, 1.5)
+    ax.set_ylim(-1.5, 1.5)
+    ax.set(aspect = 1)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
 
-        ax.legend(frameon = False, loc = 2)
+    d = train_input[:x.size(0)].detach().to('cpu').numpy()
+    ax.scatter(d[:, 0], d[:, 1],
+               s = 2.5, color = 'gray', label = 'Train')
 
-    filename = f'diffusion_{args.data}.pdf'
+    d = x.detach().to('cpu').numpy()
+    ax.scatter(d[:, 0], d[:, 1],
+               s = 2.0, color = 'red', label = 'Synthesis')
+
+    ax.legend(frameon = False, loc = 2)
+
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
-    if hasattr(plt.get_current_fig_manager(), 'window'):
+    if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
        plt.show()
 
+########################################
+# NxCxHxW -> image
 elif train_input.dim() == 4:
 
-    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, model)
-    x = 1 - x.clamp(min = 0, max = 255) / 255
-    torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)
+    x = generate((128,) + train_input.size()[1:], T, alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
+
+    x = torchvision.utils.make_grid(x.clamp(min = 0, max = 255),
+                                    nrow = 16, padding = 1, pad_value = 64)
+    x = F.pad(x, pad = (2, 2, 2, 2), value = 64)[None]
+
+    t = torchvision.utils.make_grid(train_input[:128],
+                                    nrow = 16, padding = 1, pad_value = 64)
+    t = F.pad(t, pad = (2, 2, 2, 2), value = 64)[None]
+
+    result = 1 - torch.cat((t, x), 2) / 255
+
+    filename = f'minidiffusion_{args.data}.png'
+    print(f'saving {filename}')
+    torchvision.utils.save_image(result, filename)
+
+else:
+
+    print(f'cannot plot result of size {train_input.size()}')
 
 ######################################################################
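Note on the change above: the patch replaces the manual torch.cat of the time step with a TimeAppender module placed at the head of the network, so the model is now called on a pair (x, t) and the step is appended as an extra feature or channel. Below is a minimal, self-contained sketch (not part of the patch) of that behaviour; the batch size and tensor shapes are illustrative assumptions, not values taken from the diff.

import torch
from torch import nn

# Module added by the patch: appends t (scalar or 1d tensor) to x as an
# additional dimension / channel.
class TimeAppender(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, u):
        x, t = u
        if not torch.is_tensor(t):
            t = x.new_full((x.size(0),), t)
        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
        return torch.cat((x, t), 1)

ta = TimeAppender()

x = torch.randn(4, 2)                  # assumed toy batch of 2d samples
print(ta((x, 0.25)).size())            # torch.Size([4, 3]): t appended as a column

x = torch.randn(4, 1, 28, 28)          # assumed batch of 1-channel images
print(ta((x, 0.25)).size())            # torch.Size([4, 2, 28, 28]): t appended as a channel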