From: Francois Fleuret Date: Fri, 12 Aug 2022 21:05:22 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=commitdiff_plain;h=be287add0311cc66345e5a26e297b4fe30310398 Update. --- diff --git a/minidiffusion.py b/minidiffusion.py index a386a12..0f5948e 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -10,36 +10,70 @@ # # https://arxiv.org/abs/2006.11239 +import math import matplotlib.pyplot as plt import torch from torch import nn ###################################################################### -def sample_phi(nb): +class EMA: + def __init__(self, model, decay = 0.9999): + self.model = model + self.decay = decay + self.ema = { } + with torch.no_grad(): + for p in model.parameters(): + self.ema[p] = p.clone() + + def step(self): + with torch.no_grad(): + for p in self.model.parameters(): + self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + + def copy(self): + with torch.no_grad(): + for p in self.model.parameters(): + p.copy_(self.ema[p]) + +###################################################################### + +def sample_gaussian_mixture(nb): p, std = 0.3, 0.2 - result = torch.empty(nb).normal_(0, std) + result = torch.empty(nb, 1).normal_(0, std) result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result +def sample_arc(nb): + theta = torch.rand(nb) * math.pi + rho = torch.rand(nb) * 0.1 + 0.7 + result = torch.empty(nb, 2) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +###################################################################### +# Train + +nb_samples = 25000 + +train_input = sample_gaussian_mixture(nb_samples) +#train_input = sample_arc(nb_samples) + ###################################################################### +nh = 64 + model = nn.Sequential( - nn.Linear(2, 32), + nn.Linear(train_input.size(1) + 1, nh), nn.ReLU(), - nn.Linear(32, 32), + nn.Linear(nh, nh), nn.ReLU(), - nn.Linear(32, 1), + nn.Linear(nh, train_input.size(1)), ) -###################################################################### -# Train - -nb_samples = 25000 -nb_epochs = 250 -batch_size = 100 - -train_input = sample_phi(nb_samples)[:, None] +nb_epochs = 50 +batch_size = 25 T = 1000 beta = torch.linspace(1e-4, 0.02, T) @@ -47,10 +81,12 @@ alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() +ema = EMA(model) + for k in range(nb_epochs): acc_loss = 0 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) ) + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) for x0 in train_input.split(batch_size): t = torch.randint(T, (x0.size(0), 1)) @@ -65,12 +101,16 @@ for k in range(nb_epochs): acc_loss += loss.item() + ema.step() + if k%10 == 0: print(k, loss.item()) +ema.copy() + ###################################################################### # Generate -x = torch.randn(10000, 1) +x = torch.randn(10000, train_input.size(1)) for t in range(T-1, -1, -1): z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) @@ -83,24 +123,44 @@ for t in range(T-1, -1, -1): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) -ax.set_xlim(-1.25, 1.25) -d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'stepfilled', color = 'lightblue', label = 'Train') +if train_input.size(1) == 1: + + ax.set_xlim(-1.25, 1.25) + + d = train_input.flatten().detach().numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') + + d = x.flatten().detach().numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') + + ax.legend(frameon = False, loc = 2) + +elif train_input.size(1) == 2: + + ax.set_xlim(-1.25, 1.25) + ax.set_ylim(-1.25, 1.25) + ax.set(aspect = 1) + + d = train_input[:200].detach().numpy() + ax.scatter(d[:, 0], d[:, 1], + color = 'lightblue', label = 'Train') -d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'step', color = 'red', label = 'Synthesis') + d = x[:200].detach().numpy() + ax.scatter(d[:, 0], d[:, 1], + color = 'red', label = 'Synthesis') -ax.legend(frameon = False, loc = 2) + ax.legend(frameon = False, loc = 2) filename = 'diffusion.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') +plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() ######################################################################