From: Francois Fleuret Date: Sat, 13 Aug 2022 00:31:14 +0000 (+0200) Subject: Add command line arguments and cuda support. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b52a28b72ae3a07f61aaa9fa5b6d063bbe5d4dda;p=pytorch.git Add command line arguments and cuda support. --- diff --git a/minidiffusion.py b/minidiffusion.py index 0f5948e..8d8dac0 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,60 +5,128 @@ # Written by Francois Fleuret -# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel -# "Denoising Diffusion Probabilistic Models" (2020) -# -# https://arxiv.org/abs/2006.11239 +import math, argparse -import math import matplotlib.pyplot as plt + import torch from torch import nn +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +###################################################################### + +def sample_gaussian_mixture(nb): + p, std = 0.3, 0.2 + result = torch.empty(nb, 1, device = device).normal_(0, std) + result = result + torch.sign(torch.rand(result.size(), device = device) - p) / 2 + return result + +def sample_arc(nb): + theta = torch.rand(nb, device = device) * math.pi + rho = torch.rand(nb, device = device) * 0.1 + 0.7 + result = torch.empty(nb, 2, device = device) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +def sample_spiral(nb): + u = torch.rand(nb, device = device) + rho = u * 0.65 + 0.25 + torch.rand(nb, device = device) * 0.15 + theta = u * math.pi * 3 + result = torch.empty(nb, 2, device = device) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +samplers = { + 'gaussian_mixture': sample_gaussian_mixture, + 'arc': sample_arc, + 'spiral': sample_spiral, +} + +###################################################################### + +parser = argparse.ArgumentParser( + description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +"Denoising Diffusion Probabilistic Models" (2020) +https://arxiv.org/abs/2006.11239''', + + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed, < 0 is no seeding') + +parser.add_argument('--nb_epochs', + type = int, default = 100, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 25, + help = 'Batch size') + +parser.add_argument('--nb_samples', + type = int, default = 25000, + help = 'Number of training examples') + +parser.add_argument('--learning_rate', + type = float, default = 1e-3, + help = 'Learning rate') + +parser.add_argument('--ema_decay', + type = float, default = 0.9999, + help = 'EMA decay, < 0 means no EMA') + +data_list = ', '.join( [ str(k) for k in samplers ]) + +parser.add_argument('--data', + type = str, default = 'gaussian_mixture', + help = f'Toy data-set to use: {data_list}') + +args = parser.parse_args() + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + ###################################################################### class EMA: - def __init__(self, model, decay = 0.9999): + def __init__(self, model, decay): self.model = model self.decay = decay + if self.decay < 0: return self.ema = { } with torch.no_grad(): for p in model.parameters(): self.ema[p] = p.clone() def step(self): + if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) def copy(self): + if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): p.copy_(self.ema[p]) -###################################################################### - -def sample_gaussian_mixture(nb): - p, std = 0.3, 0.2 - result = torch.empty(nb, 1).normal_(0, std) - result = result + torch.sign(torch.rand(result.size()) - p) / 2 - return result - -def sample_arc(nb): - theta = torch.rand(nb) * math.pi - rho = torch.rand(nb) * 0.1 + 0.7 - result = torch.empty(nb, 2) - result[:, 0] = theta.cos() * rho - result[:, 1] = theta.sin() * rho - return result - ###################################################################### # Train -nb_samples = 25000 - -train_input = sample_gaussian_mixture(nb_samples) -#train_input = sample_arc(nb_samples) +try: + train_input = samplers[args.data](args.nb_samples) +except KeyError: + print(f'unknown data {args.data}') + exit(1) ###################################################################### @@ -69,28 +137,27 @@ model = nn.Sequential( nn.ReLU(), nn.Linear(nh, nh), nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), nn.Linear(nh, train_input.size(1)), -) - -nb_epochs = 50 -batch_size = 25 +).to(device) T = 1000 -beta = torch.linspace(1e-4, 0.02, T) +beta = torch.linspace(1e-4, 0.02, T, device = device) alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -ema = EMA(model) +ema = EMA(model, decay = args.ema_decay) -for k in range(nb_epochs): +for k in range(args.nb_epochs): acc_loss = 0 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) - for x0 in train_input.split(batch_size): - t = torch.randint(T, (x0.size(0), 1)) - eps = torch.randn(x0.size()) + for x0 in train_input.split(args.batch_size): + t = torch.randint(T, (x0.size(0), 1), device = device) + eps = torch.randn(x0.size(), device = device) input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps input = torch.cat((input, 2 * t / T - 1), 1) output = model(input) @@ -99,22 +166,22 @@ for k in range(nb_epochs): loss.backward() optimizer.step() - acc_loss += loss.item() + acc_loss += loss.item() * x0.size(0) ema.step() - if k%10 == 0: print(k, loss.item()) + if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}') ema.copy() ###################################################################### # Generate -x = torch.randn(10000, train_input.size(1)) +x = torch.randn(10000, train_input.size(1), device = device) for t in range(T-1, -1, -1): - z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) - input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) + z = torch.zeros(x.size(), device = device) if t == 0 else torch.randn(x.size(), device = device) + input = torch.cat((x, torch.ones(x.size(0), 1, device = device) * 2 * t / T - 1), 1) x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ + sigma[t] * z @@ -128,12 +195,12 @@ if train_input.size(1) == 1: ax.set_xlim(-1.25, 1.25) - d = train_input.flatten().detach().numpy() + d = train_input.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), density = True, histtype = 'stepfilled', color = 'lightblue', label = 'Train') - d = x.flatten().detach().numpy() + d = x.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), density = True, histtype = 'step', color = 'red', label = 'Synthesis') @@ -146,21 +213,22 @@ elif train_input.size(1) == 2: ax.set_ylim(-1.25, 1.25) ax.set(aspect = 1) - d = train_input[:200].detach().numpy() + d = train_input[:200].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], color = 'lightblue', label = 'Train') - d = x[:200].detach().numpy() + d = x[:200].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], color = 'red', label = 'Synthesis') ax.legend(frameon = False, loc = 2) -filename = 'diffusion.pdf' +filename = f'diffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') -plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) -plt.show() +if hasattr(plt.get_current_fig_manager(), 'window'): + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + plt.show() ######################################################################