diff --git a/minidiffusion.py b/minidiffusion.py
index 6fd8564..841dd2a 100755
--- a/minidiffusion.py
+++ b/minidiffusion.py
@@ -14,6 +14,8 @@ from torch import nn
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+print(f'device {device}')
+
 ######################################################################
 
 def sample_gaussian_mixture(nb):
@@ -103,7 +105,7 @@ parser.add_argument('--learning_rate',
 
 parser.add_argument('--ema_decay',
                     type = float, default = 0.9999,
-                    help = 'EMA decay, < 0 is no EMA')
+                    help = 'EMA decay, <= 0 is no EMA')
 
 data_list = ', '.join( [ str(k) for k in samplers ])
 
@@ -127,23 +129,20 @@ class EMA:
     def __init__(self, model, decay):
         self.model = model
         self.decay = decay
-        if self.decay < 0: return
-        self.ema = { }
+        self.mem = { }
         with torch.no_grad():
             for p in model.parameters():
-                self.ema[p] = p.clone()
+                self.mem[p] = p.clone()
 
     def step(self):
-        if self.decay < 0: return
         with torch.no_grad():
             for p in self.model.parameters():
-                self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p)
+                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
 
-    def copy(self):
-        if self.decay < 0: return
+    def copy_to_model(self):
         with torch.no_grad():
             for p in self.model.parameters():
-                p.copy_(self.ema[p])
+                p.copy_(self.mem[p])
 
 ######################################################################
 
@@ -205,6 +204,26 @@ model.to(device)
 
 print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
 
+######################################################################
+# Generate
+
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
+    with torch.no_grad():
+
+        x = torch.randn(size, device = device)
+
+        for t in range(T-1, -1, -1):
+            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
+            x = 1/torch.sqrt(alpha[t]) \
+                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+                + sigma[t] * z
+
+        x = x * train_std + train_mean
+
+    return x
+
 ######################################################################
 # Train
 
@@ -214,7 +233,7 @@ alpha = 1 - beta
 alpha_bar = alpha.log().cumsum(0).exp()
 sigma = beta.sqrt()
 
-ema = EMA(model, decay = args.ema_decay)
+ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None
 
 for k in range(args.nb_epochs):
 
@@ -225,8 +244,8 @@ for k in range(args.nb_epochs):
         x0 = (x0 - train_mean) / train_std
         t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device)
         eps = torch.randn_like(x0)
-        input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
-        input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
+        xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps
+        input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1)
         loss = (eps - model(input)).pow(2).mean()
         acc_loss += loss.item() * x0.size(0)
 
@@ -234,29 +253,11 @@ for k in range(args.nb_epochs):
         loss.backward()
         optimizer.step()
 
-        ema.step()
+        if ema is not None: ema.step()
 
    print(f'{k} {acc_loss / train_input.size(0)}')
 
-ema.copy()
-
-######################################################################
-# Generate
-
-def generate(size, model):
-    with torch.no_grad():
-        x = torch.randn(size, device = device)
-
-        for t in range(T-1, -1, -1):
-            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
-            input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
-            x = 1/torch.sqrt(alpha[t]) \
-                * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
-                + sigma[t] * z
-
-        x = x * train_std + train_mean
-
-        return x
+if ema is not None: ema.copy_to_model()
 
 ######################################################################
 # Plot
@@ -264,14 +265,18 @@ def generate(size, model):
 model.eval()
 
 if train_input.dim() == 2:
+
     fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1)
 
     if train_input.size(1) == 1:
 
-        x = generate((10000, 1), model)
+        x = generate((10000, 1), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.25, 1.25)
+        ax.spines.right.set_visible(False)
+        ax.spines.top.set_visible(False)
 
         d = train_input.flatten().detach().to('cpu').numpy()
         ax.hist(d, 25, (-1, 1),
@@ -287,7 +292,8 @@ if train_input.dim() == 2:
 
    elif train_input.size(1) == 2:
 
-        x = generate((1000, 2), model)
+        x = generate((1000, 2), alpha, alpha_bar, sigma,
+                     model, train_mean, train_std)
 
         ax.set_xlim(-1.5, 1.5)
         ax.set_ylim(-1.5, 1.5)
@@ -297,11 +303,11 @@ if train_input.dim() == 2:
 
         d = x.detach().to('cpu').numpy()
         ax.scatter(d[:, 0], d[:, 1],
-                   facecolors = 'none', color = 'red', label = 'Synthesis')
+                   s = 2.0, color = 'red', label = 'Synthesis')
 
         d = train_input[:x.size(0)].detach().to('cpu').numpy()
         ax.scatter(d[:, 0], d[:, 1],
-                   s = 1.0, color = 'blue', label = 'Train')
+                   s = 2.0, color = 'gray', label = 'Train')
 
         ax.legend(frameon = False, loc = 2)
 
@@ -314,7 +320,9 @@ if train_input.dim() == 2:
     plt.show()
 
 elif train_input.dim() == 4:
-    x = generate((128,) + train_input.size()[1:], model)
+
+    x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma,
+                 model, train_mean, train_std)
     x = 1 - x.clamp(min = 0, max = 255) / 255
     torchvision.utils.save_image(x, f'diffusion_{args.data}.png',
                                  nrow = 16, pad_value = 0.8)
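
For reference, below is a minimal, self-contained sketch of the two DDPM computations this commit reorganizes: the forward noising used in the training loop and the ancestral-sampling update now factored into generate(). The value of T, the linear beta schedule, and the toy eps_model stand-in are illustrative assumptions; the shown hunks do not define T or beta, and in the real script the noise predictor is the trained network.

import torch

# Assumed values for illustration -- not defined in the shown hunks.
T = 1000
beta = torch.linspace(1e-4, 0.02, T)      # assumed linear schedule
alpha = 1 - beta
alpha_bar = alpha.log().cumsum(0).exp()   # cumulative product, as in the patch
sigma = beta.sqrt()

def eps_model(inp):
    # Toy stand-in for the trained noise predictor. It receives the sample
    # with a normalized time channel appended and returns a tensor the
    # shape of the sample alone.
    return torch.zeros_like(inp[:, :-1])

# Forward noising, as in the training loop:
#   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
x0 = torch.randn(8, 1)
t = torch.randint(T, (x0.size(0), 1))
eps = torch.randn_like(x0)
xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps

# Reverse process, as in generate(): start from pure noise and apply
#   x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta)
#             / sqrt(alpha_t) + sigma_t * z
# with z = 0 at the last step, so the final update is deterministic.
x = torch.randn(8, 1)
for s in range(T - 1, -1, -1):
    z = torch.zeros_like(x) if s == 0 else torch.randn_like(x)
    inp = torch.cat((x, torch.full_like(x[:, :1], s / (T - 1) - 0.5)), 1)
    x = (x - (1 - alpha[s]) / torch.sqrt(1 - alpha_bar[s]) * eps_model(inp)) \
        / torch.sqrt(alpha[s]) + sigma[s] * z

This is the standard DDPM update of Ho et al. (2020). Passing alpha, alpha_bar, sigma, train_mean and train_std to generate() explicitly, as the commit does, makes the function mostly self-contained (it still reads T and device from module scope), which is presumably what allows its definition to move above the training loop.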