X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=65ca94737443bcf8bc179aef884ceb30f6897886;hb=a810bbe6c5bc84f66e4fdb85dca41a39bd71afac;hp=2c54d196062ee1775385335d67c03ba29a34b3ca;hpb=a3c7617d0b5770edf6030502e4eac477a7218820;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 2c54d19..65ca947 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -14,14 +14,20 @@ from torch import nn device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'device {device}') + ###################################################################### def sample_gaussian_mixture(nb): p, std = 0.3, 0.2 - result = torch.empty(nb, 1).normal_(0, std) + result = torch.randn(nb, 1) * std result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result +def sample_ramp(nb): + result = torch.min(torch.rand(nb, 1), torch.rand(nb, 1)) + return result + def sample_two_discs(nb): a = torch.rand(nb) * math.pi * 2 b = torch.rand(nb).sqrt() @@ -35,8 +41,9 @@ def sample_two_discs(nb): def sample_disc_grid(nb): a = torch.rand(nb) * math.pi * 2 b = torch.rand(nb).sqrt() - q = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5 - r = torch.randint(5, (nb,)) / 2.5 - 2 / 2.5 + N = 4 + q = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2) + r = (torch.randint(N, (nb,)) - (N - 1) / 2) / ((N - 1) / 2) b = b * 0.1 result = torch.empty(nb, 2) result[:, 0] = a.cos() * b + q @@ -59,6 +66,7 @@ def sample_mnist(nb): samplers = { 'gaussian_mixture': sample_gaussian_mixture, + 'ramp': sample_ramp, 'two_discs': sample_two_discs, 'disc_grid': sample_disc_grid, 'spiral': sample_spiral, @@ -97,7 +105,7 @@ parser.add_argument('--learning_rate', parser.add_argument('--ema_decay', type = float, default = 0.9999, - help = 'EMA decay, < 0 is no EMA') + help = 'EMA decay, <= 0 is no EMA') data_list = ', '.join( [ str(k) for k in samplers ]) @@ -121,23 +129,20 @@ class EMA: def __init__(self, model, decay): self.model = model self.decay = decay - if self.decay < 0: return - self.ema = { } + self.mem = { } with torch.no_grad(): for p in model.parameters(): - self.ema[p] = p.clone() + self.mem[p] = p.clone() def step(self): - if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): - self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p) - def copy(self): - if self.decay < 0: return + def copy_to_model(self): with torch.no_grad(): for p in self.model.parameters(): - p.copy_(self.ema[p]) + p.copy_(self.mem[p]) ###################################################################### @@ -179,7 +184,7 @@ train_mean, train_std = train_input.mean(), train_input.std() # Model if train_input.dim() == 2: - nh = 64 + nh = 256 model = nn.Sequential( nn.Linear(train_input.size(1) + 1, nh), @@ -197,6 +202,28 @@ elif train_input.dim() == 4: model.to(device) +print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}') + +###################################################################### +# Generate + +def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std): + + with torch.no_grad(): + + x = torch.randn(size, device = device) + + for t in range(T-1, -1, -1): + z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) + input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1) + x = 1/torch.sqrt(alpha[t]) \ + * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \ + + sigma[t] * z + + x = x * train_std + train_mean + + return x + ###################################################################### # Train @@ -206,7 +233,7 @@ alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -ema = EMA(model, decay = args.ema_decay) +ema = EMA(model, decay = args.ema_decay) if args.ema_decay > 0 else None for k in range(args.nb_epochs): @@ -217,8 +244,8 @@ for k in range(args.nb_epochs): x0 = (x0 - train_mean) / train_std t = torch.randint(T, (x0.size(0),) + (1,) * (x0.dim() - 1), device = x0.device) eps = torch.randn_like(x0) - input = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps - input = torch.cat((input, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1) + xt = torch.sqrt(alpha_bar[t]) * x0 + torch.sqrt(1 - alpha_bar[t]) * eps + input = torch.cat((xt, t.expand_as(x0[:,:1]) / (T - 1) - 0.5), 1) loss = (eps - model(input)).pow(2).mean() acc_loss += loss.item() * x0.size(0) @@ -226,29 +253,11 @@ for k in range(args.nb_epochs): loss.backward() optimizer.step() - ema.step() + if ema is not None: ema.step() - if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}') + print(f'{k} {acc_loss / train_input.size(0)}') -ema.copy() - -###################################################################### -# Generate - -def generate(size, model): - with torch.no_grad(): - x = torch.randn(size, device = device) - - for t in range(T-1, -1, -1): - z = torch.zeros_like(x) if t == 0 else torch.randn_like(x) - input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1) - x = 1/torch.sqrt(alpha[t]) \ - * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \ - + sigma[t] * z - - x = x * train_std + train_mean - - return x +if ema is not None: ema.copy_to_model() ###################################################################### # Plot @@ -256,14 +265,19 @@ def generate(size, model): model.eval() if train_input.dim() == 2: + fig = plt.figure() ax = fig.add_subplot(1, 1, 1) + # Nx1 -> histogram if train_input.size(1) == 1: - x = generate((10000, 1), model) + x = generate((10000, 1), alpha, alpha_bar, sigma, + model, train_mean, train_std) ax.set_xlim(-1.25, 1.25) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) d = train_input.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), @@ -277,21 +291,25 @@ if train_input.dim() == 2: ax.legend(frameon = False, loc = 2) + # Nx2 -> scatter plot elif train_input.size(1) == 2: - x = generate((1000, 2), model) + x = generate((1000, 2), alpha, alpha_bar, sigma, + model, train_mean, train_std) - ax.set_xlim(-1.25, 1.25) - ax.set_ylim(-1.25, 1.25) + ax.set_xlim(-1.5, 1.5) + ax.set_ylim(-1.5, 1.5) ax.set(aspect = 1) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) - d = train_input[:x.size(0)].detach().to('cpu').numpy() + d = x.detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], - color = 'lightblue', label = 'Train') + s = 2.0, color = 'red', label = 'Synthesis') - d = x.detach().to('cpu').numpy() + d = train_input[:x.size(0)].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], - facecolors = 'none', color = 'red', label = 'Synthesis') + s = 2.0, color = 'gray', label = 'Train') ax.legend(frameon = False, loc = 2) @@ -303,8 +321,11 @@ if train_input.dim() == 2: plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() +# NxCxHxW -> image elif train_input.dim() == 4: - x = generate((128,) + train_input.size()[1:], model) + + x = generate((128,) + train_input.size()[1:], alpha, alpha_bar, sigma, + model, train_mean, train_std) x = 1 - x.clamp(min = 0, max = 255) / 255 torchvision.utils.save_image(x, f'diffusion_{args.data}.png', nrow = 16, pad_value = 0.8)