From 2d95f238bbaa0e585b50846d39c98df4aae2b7f9 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 6 Jun 2021 15:55:07 +0200 Subject: [PATCH 01/16] Update. --- conv_chain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conv_chain.py b/conv_chain.py index 184e06b..2d5af8e 100755 --- a/conv_chain.py +++ b/conv_chain.py @@ -14,10 +14,10 @@ def conv_chain(input_size, output_size, depth, cond): else: r = [ ] for kernel_size in range(1, input_size + 1): - for stride in range(1, input_size + 1): + for stride in range(1, input_size): if cond(depth, kernel_size, stride): n = (input_size - kernel_size) // stride + 1 - if (n - 1) * stride + kernel_size == input_size: + if n >= output_size and (n - 1) * stride + kernel_size == input_size: q = conv_chain(n, output_size, depth - 1, cond) r += [ [ (kernel_size, stride) ] + u for u in q ] return r -- 2.20.1 From 0c2da9bc90d21bdbcbcbeeb0a070653cbc3c25cc Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sun, 6 Jun 2021 16:08:03 +0200 Subject: [PATCH 02/16] Update. --- conv_chain.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/conv_chain.py b/conv_chain.py index 2d5af8e..d10798f 100755 --- a/conv_chain.py +++ b/conv_chain.py @@ -1,7 +1,9 @@ #!/usr/bin/env python -import torch -from torch import nn +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret ###################################################################### @@ -26,13 +28,16 @@ def conv_chain(input_size, output_size, depth, cond): if __name__ == "__main__": + import torch + from torch import nn + # Example c = conv_chain( input_size = 64, output_size = 8, depth = 5, # We want kernels smaller than 4, strides smaller than the - # kernels, and stride of 1 except in the two last layers + # kernels, and strides of 1 except in the two last layers cond = lambda d, k, s: k <= 4 and s <= k and (s == 1 or d <= 2) ) -- 2.20.1 From dcb8e93a2f882abf1a30326fe419a592484deb18 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 7 Dec 2021 08:19:33 +0100 Subject: [PATCH 03/16] Initial commit. --- tinyae.py | 163 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100755 tinyae.py diff --git a/tinyae.py b/tinyae.py new file mode 100755 index 0000000..c608c9c --- /dev/null +++ b/tinyae.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import sys, argparse, time + +import torch, torchvision + +from torch import optim, nn +from torch.nn import functional as F + +###################################################################### + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +###################################################################### + +parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.') + +parser.add_argument('--nb_epochs', + type = int, default = 25) + +parser.add_argument('--batch_size', + type = int, default = 100) + +parser.add_argument('--data_dir', + type = str, default = './data/') + +parser.add_argument('--log_filename', + type = str, default = 'train.log') + +parser.add_argument('--embedding_dim', + type = int, default = 8) + +parser.add_argument('--nb_channels', + type = int, default = 32) + +parser.add_argument('--force_train', + type = bool, default = False) + +args = parser.parse_args() + +log_file = open(args.log_filename, 'w') + +###################################################################### + +def log_string(s): + t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + '\n') + log_file.flush() + + print(t + s) + sys.stdout.flush() + +###################################################################### + +class AutoEncoder(nn.Module): + def __init__(self, nb_channels, embedding_dim): + super(AutoEncoder, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24 + nn.ReLU(inplace = True), + nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20 + nn.ReLU(inplace = True), + nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9 + nn.ReLU(inplace = True), + nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4 + nn.ReLU(inplace = True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4) + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4), + nn.ReLU(inplace = True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4 + nn.ReLU(inplace = True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9 + nn.ReLU(inplace = True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20 + nn.ReLU(inplace = True), + nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24 + ) + + def encode(self, x): + return self.encoder(x).view(x.size(0), -1) + + def decode(self, z): + return self.decoder(z.view(z.size(0), -1, 1, 1)) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + +###################################################################### + +train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', + train = True, download = True) +train_input = train_set.data.view(-1, 1, 28, 28).float() + +test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', + train = False, download = True) +test_input = test_set.data.view(-1, 1, 28, 28).float() + +###################################################################### + +model = AutoEncoder(args.nb_channels, args.embedding_dim) +optimizer = optim.Adam(model.parameters(), lr = 1e-3) + +model.to(device) + +train_input, test_input = train_input.to(device), test_input.to(device) + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +###################################################################### + +for epoch in range(args.nb_epochs): + + acc_loss = 0 + + for input in train_input.split(args.batch_size): + output = model(input) + loss = 0.5 * (output - input).pow(2).sum() / input.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss)) + +###################################################################### + +input = test_input[:256] + +# Encode / decode + +z = model.encode(input) +output = model.decode(z) + +torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8) +torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8) + +# Dumb synthesis + +z = model.encode(input) +mu, std = z.mean(0), z.std(0) +z = z.normal_() * std + mu +output = model.decode(z) + +torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8) + +###################################################################### -- 2.20.1 From 1f7353f882933264f76a4584ba8cf36f055e4658 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 7 Dec 2021 08:19:35 +0100 Subject: [PATCH 04/16] Update. --- conv_chain.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/conv_chain.py b/conv_chain.py index d10798f..3077874 100755 --- a/conv_chain.py +++ b/conv_chain.py @@ -7,8 +7,8 @@ ###################################################################### -def conv_chain(input_size, output_size, depth, cond): - if depth == 0: +def conv_chain(input_size, output_size, remain_depth, cond): + if remain_depth == 0: if input_size == output_size: return [ [ ] ] else: @@ -17,10 +17,10 @@ def conv_chain(input_size, output_size, depth, cond): r = [ ] for kernel_size in range(1, input_size + 1): for stride in range(1, input_size): - if cond(depth, kernel_size, stride): + if cond(remain_depth, kernel_size, stride): n = (input_size - kernel_size) // stride + 1 if n >= output_size and (n - 1) * stride + kernel_size == input_size: - q = conv_chain(n, output_size, depth - 1, cond) + q = conv_chain(n, output_size, remain_depth - 1, cond) r += [ [ (kernel_size, stride) ] + u for u in q ] return r @@ -35,7 +35,7 @@ if __name__ == "__main__": c = conv_chain( input_size = 64, output_size = 8, - depth = 5, + remain_depth = 5, # We want kernels smaller than 4, strides smaller than the # kernels, and strides of 1 except in the two last layers cond = lambda d, k, s: k <= 4 and s <= k and (s == 1 or d <= 2) -- 2.20.1 From d09d91f2b5b594f91a757134c5ce014ae8d68a9a Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 7 Dec 2021 08:51:04 +0100 Subject: [PATCH 05/16] Update. --- autoencoder.py | 159 ------------------------------------------------- 1 file changed, 159 deletions(-) delete mode 100755 autoencoder.py diff --git a/autoencoder.py b/autoencoder.py deleted file mode 100755 index 22929af..0000000 --- a/autoencoder.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python - -import sys, argparse, os, time - -import torch, torchvision - -from torch import optim, nn -from torch.nn import functional as F - -import torchvision - -###################################################################### - -if torch.cuda.is_available(): - device = torch.device('cuda') -else: - device = torch.device('cpu') - -###################################################################### - -parser = argparse.ArgumentParser(description = 'Simple auto-encoder.') - -parser.add_argument('--nb_epochs', - type = int, default = 25) - -parser.add_argument('--batch_size', - type = int, default = 100) - -parser.add_argument('--data_dir', - type = str, default = './data/') - -parser.add_argument('--log_filename', - type = str, default = 'train.log') - -parser.add_argument('--embedding_dim', - type = int, default = 16) - -parser.add_argument('--nb_channels', - type = int, default = 32) - -parser.add_argument('--force_train', - type = bool, default = False) - -args = parser.parse_args() - -log_file = open(args.log_filename, 'w') - -###################################################################### - -def log_string(s, color = None): - t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) - - if log_file is not None: - log_file.write(t + s + '\n') - log_file.flush() - - print(t + s) - sys.stdout.flush() - -###################################################################### - -class AutoEncoder(nn.Module): - def __init__(self, nb_channels, embedding_dim): - super(AutoEncoder, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4 - nn.ReLU(inplace = True), - nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4) - ) - - self.decoder = nn.Sequential( - nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4), - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20 - nn.ReLU(inplace = True), - nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24 - ) - - def encode(self, x): - return self.encoder(x).view(x.size(0), -1) - - def decode(self, z): - return self.decoder(z.view(z.size(0), -1, 1, 1)) - - def forward(self, x): - x = self.encoder(x) - # print(x.size()) - x = self.decoder(x) - return x - -###################################################################### - -train_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', - train = True, download = True) -train_input = train_set.data.view(-1, 1, 28, 28).float() - -test_set = torchvision.datasets.MNIST(args.data_dir + '/mnist/', - train = False, download = True) -test_input = test_set.data.view(-1, 1, 28, 28).float() - -###################################################################### - -train_input, test_input = train_input.to(device), test_input.to(device) - -mu, std = train_input.mean(), train_input.std() -train_input.sub_(mu).div_(std) -test_input.sub_(mu).div_(std) - -model = AutoEncoder(args.nb_channels, args.embedding_dim) -optimizer = optim.Adam(model.parameters(), lr = 1e-3) - -model.to(device) - -for epoch in range(args.nb_epochs): - acc_loss = 0 - for input in train_input.split(args.batch_size): - input = input.to(device) - z = model.encode(input) - output = model.decode(z) - loss = 0.5 * (output - input).pow(2).sum() / input.size(0) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - acc_loss += loss.item() - - log_string(f'acc_loss {epoch} {acc_loss}', 'blue') - -###################################################################### - -input = test_input[:256] -z = model.encode(input) -output = model.decode(z) - -torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8) -torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8) - -###################################################################### - -input = train_input[:256] -z = model.encode(input) -mu, std = z.mean(0), z.std(0) -z = z.normal_() * std + mu -output = model.decode(z) -torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8) - -###################################################################### -- 2.20.1 From 5be92059a4bc81db8aad8b677d3387800de2aae8 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 8 Dec 2021 14:42:45 +0100 Subject: [PATCH 06/16] Update. --- tinyae.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tinyae.py b/tinyae.py index c608c9c..70484f1 100755 --- a/tinyae.py +++ b/tinyae.py @@ -38,9 +38,6 @@ parser.add_argument('--embedding_dim', parser.add_argument('--nb_channels', type = int, default = 32) -parser.add_argument('--force_train', - type = bool, default = False) - args = parser.parse_args() log_file = open(args.log_filename, 'w') -- 2.20.1 From c16fa89db08b59e454c6ca4b5c68bf7396e876dc Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 10 Dec 2021 22:53:37 +0100 Subject: [PATCH 07/16] Update. --- elbo.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100755 elbo.py diff --git a/elbo.py b/elbo.py new file mode 100755 index 0000000..24155fe --- /dev/null +++ b/elbo.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch + +def D_KL(p, q): + return - p @ (q / p).log() + +# p(X = x, Z = z) = p[x, z] +p = torch.rand(5, 4) +p /= p.sum() + +q = torch.rand(p.size()) +q /= q.sum() + +p_X = p.sum(1) +p_Z = p.sum(0) +p_X_given_Z = p / p.sum(0, keepdim = True) +p_Z_given_X = p / p.sum(1, keepdim = True) +q_X_given_Z = q / q.sum(0, keepdim = True) +q_Z_given_X = q / q.sum(1, keepdim = True) + +for x in range(p.size(0)): + elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() + print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :])) -- 2.20.1 From 47525ec795faca1ab72aee13956a553d070c81b7 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 14 Mar 2022 13:22:13 +0100 Subject: [PATCH 08/16] Update. --- ddpol.py | 9 +++++---- elbo.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/ddpol.py b/ddpol.py index f33b0a1..645f47c 100755 --- a/ddpol.py +++ b/ddpol.py @@ -50,12 +50,13 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): r = q.view(-1, 1) beta = x.new_zeros(D + 1, D + 1) beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3) - l, U = beta.eig(eigenvectors = True) - Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values + W = torch.linalg.eig(beta) + l, U = W.eigenvalues.real, W.eigenvectors.real + Q = U @ torch.diag(l.clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) - return torch.lstsq(B, M).solution[:D+1, 0] + return torch.linalg.lstsq(M, B).solution[:D+1] ###################################################################### @@ -99,7 +100,7 @@ ax.set_ylabel('MSE', labelpad = 10) ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5, linestyle = '--') -ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples', +ax.text(args.nb_train_samples - 1.2, 1e-4, 'nb. params = nb. samples', fontsize = 10, color = 'gray', rotation = 90, rotation_mode='anchor') diff --git a/elbo.py b/elbo.py index 24155fe..6af4a77 100755 --- a/elbo.py +++ b/elbo.py @@ -7,23 +7,24 @@ import torch -def D_KL(p, q): - return - p @ (q / p).log() +def D_KL(a, b): + return - a @ (b / a).log() # p(X = x, Z = z) = p[x, z] -p = torch.rand(5, 4) -p /= p.sum() -q = torch.rand(p.size()) -q /= q.sum() +p_XZ = torch.rand(5, 4) +p_XZ /= p_XZ.sum() +q_XZ = torch.rand(p_XZ.size()) +q_XZ /= q_XZ.sum() -p_X = p.sum(1) -p_Z = p.sum(0) -p_X_given_Z = p / p.sum(0, keepdim = True) -p_Z_given_X = p / p.sum(1, keepdim = True) -q_X_given_Z = q / q.sum(0, keepdim = True) -q_Z_given_X = q / q.sum(1, keepdim = True) +p_X = p_XZ.sum(1) +p_Z = p_XZ.sum(0) +p_X_given_Z = p_XZ / p_XZ.sum(0, keepdim = True) +p_Z_given_X = p_XZ / p_XZ.sum(1, keepdim = True) -for x in range(p.size(0)): +#q_X_given_Z = q_XZ / q_XZ.sum(0, keepdim = True) +q_Z_given_X = q_XZ / q_XZ.sum(1, keepdim = True) + +for x in range(p_XZ.size(0)): elbo = q_Z_given_X[x, :] @ ( p_X_given_Z[x, :] / q_Z_given_X[x, :] * p_Z).log() print(p_X[x].log(), elbo + D_KL(q_Z_given_X[x, :], p_Z_given_X[x, :])) -- 2.20.1 From bc937c74ad8cbeede2c2ae97cda72eaa3e9bb4f3 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 11 Aug 2022 22:44:23 +0200 Subject: [PATCH 09/16] Initial commit --- minidiffusion.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100755 minidiffusion.py diff --git a/minidiffusion.py b/minidiffusion.py new file mode 100755 index 0000000..cbdb142 --- /dev/null +++ b/minidiffusion.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import matplotlib.pyplot as plt +import torch +from torch import nn + +###################################################################### + +def sample_phi(nb): + p, std = 0.3, 0.2 + result = torch.empty(nb).normal_(0, std) + result = result + torch.sign(torch.rand(result.size()) - p) / 2 + return result + +###################################################################### + +model = nn.Sequential( + nn.Linear(2, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 1), +) + +###################################################################### +# Train + +nb_samples = 25000 +nb_epochs = 250 +batch_size = 100 + +train_input = sample_phi(nb_samples)[:, None] + +T = 1000 +beta = torch.linspace(1e-4, 0.02, T) +alpha = 1 - beta +alpha_bar = alpha.log().cumsum(0).exp() +sigma = beta.sqrt() + +for k in range(nb_epochs): + acc_loss = 0 + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) ) + + for x0 in train_input.split(batch_size): + t = torch.randint(T, (x0.size(0), 1)) + eps = torch.randn(x0.size()) + input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps + input = torch.cat((input, 2 * t / T - 1), 1) + output = model(input) + loss = (eps - output).pow(2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + if k%10 == 0: print(k, loss.item()) + +###################################################################### +# Plot + +x = torch.randn(10000, 1) + +for t in range(T-1, -1, -1): + z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) + input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) + x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +ax.set_xlim(-1.25, 1.25) + +d = train_input.flatten().detach().numpy() +ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train') + +d = x.flatten().detach().numpy() +ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') + +filename = 'diffusion.pdf' +fig.savefig(filename, bbox_inches='tight') + +plt.show() + +###################################################################### -- 2.20.1 From b740a738f11ec566e99ac9c4f674119e7e9428b7 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 11 Aug 2022 22:52:34 +0200 Subject: [PATCH 10/16] Update. --- minidiffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/minidiffusion.py b/minidiffusion.py index cbdb142..ad1cda0 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -81,6 +81,8 @@ ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = d = x.flatten().detach().numpy() ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') +ax.legend(frameon = False, loc = 2) + filename = 'diffusion.pdf' fig.savefig(filename, bbox_inches='tight') -- 2.20.1 From f27d6083fbe7243f5896ddd49587fe1923fe9a79 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 12 Aug 2022 09:57:09 +0200 Subject: [PATCH 11/16] Update. --- minidiffusion.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/minidiffusion.py b/minidiffusion.py index ad1cda0..037ef11 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +# "Denoising Diffusion Probabilistic Models" (2020) +# +# https://arxiv.org/abs/2006.11239 + import matplotlib.pyplot as plt import torch from torch import nn @@ -62,7 +67,7 @@ for k in range(nb_epochs): if k%10 == 0: print(k, loss.item()) ###################################################################### -# Plot +# Generate x = torch.randn(10000, 1) @@ -71,19 +76,27 @@ for t in range(T-1, -1, -1): input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z +###################################################################### +# Plot + fig = plt.figure() ax = fig.add_subplot(1, 1, 1) ax.set_xlim(-1.25, 1.25) d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') ax.legend(frameon = False, loc = 2) filename = 'diffusion.pdf' +print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') plt.show() -- 2.20.1 From 317cc211cf9589a9eee5d937f0d0182719f24790 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 12 Aug 2022 10:01:40 +0200 Subject: [PATCH 12/16] Update. --- minidiffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/minidiffusion.py b/minidiffusion.py index 037ef11..6855752 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -74,7 +74,8 @@ x = torch.randn(10000, 1) for t in range(T-1, -1, -1): z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) - x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z + x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ + + sigma[t] * z ###################################################################### # Plot -- 2.20.1 From 9338cb78aa1b8260d050615f5473c5a23cae3108 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 12 Aug 2022 10:32:43 +0200 Subject: [PATCH 13/16] Update. --- minidiffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/minidiffusion.py b/minidiffusion.py index 6855752..a386a12 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -48,6 +48,7 @@ alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() for k in range(nb_epochs): + acc_loss = 0 optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) ) -- 2.20.1 From be287add0311cc66345e5a26e297b4fe30310398 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Fri, 12 Aug 2022 23:05:22 +0200 Subject: [PATCH 14/16] Update. --- minidiffusion.py | 110 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 85 insertions(+), 25 deletions(-) diff --git a/minidiffusion.py b/minidiffusion.py index a386a12..0f5948e 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -10,36 +10,70 @@ # # https://arxiv.org/abs/2006.11239 +import math import matplotlib.pyplot as plt import torch from torch import nn ###################################################################### -def sample_phi(nb): +class EMA: + def __init__(self, model, decay = 0.9999): + self.model = model + self.decay = decay + self.ema = { } + with torch.no_grad(): + for p in model.parameters(): + self.ema[p] = p.clone() + + def step(self): + with torch.no_grad(): + for p in self.model.parameters(): + self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) + + def copy(self): + with torch.no_grad(): + for p in self.model.parameters(): + p.copy_(self.ema[p]) + +###################################################################### + +def sample_gaussian_mixture(nb): p, std = 0.3, 0.2 - result = torch.empty(nb).normal_(0, std) + result = torch.empty(nb, 1).normal_(0, std) result = result + torch.sign(torch.rand(result.size()) - p) / 2 return result +def sample_arc(nb): + theta = torch.rand(nb) * math.pi + rho = torch.rand(nb) * 0.1 + 0.7 + result = torch.empty(nb, 2) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +###################################################################### +# Train + +nb_samples = 25000 + +train_input = sample_gaussian_mixture(nb_samples) +#train_input = sample_arc(nb_samples) + ###################################################################### +nh = 64 + model = nn.Sequential( - nn.Linear(2, 32), + nn.Linear(train_input.size(1) + 1, nh), nn.ReLU(), - nn.Linear(32, 32), + nn.Linear(nh, nh), nn.ReLU(), - nn.Linear(32, 1), + nn.Linear(nh, train_input.size(1)), ) -###################################################################### -# Train - -nb_samples = 25000 -nb_epochs = 250 -batch_size = 100 - -train_input = sample_phi(nb_samples)[:, None] +nb_epochs = 50 +batch_size = 25 T = 1000 beta = torch.linspace(1e-4, 0.02, T) @@ -47,10 +81,12 @@ alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() +ema = EMA(model) + for k in range(nb_epochs): acc_loss = 0 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 * (1 - k / nb_epochs) ) + optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) for x0 in train_input.split(batch_size): t = torch.randint(T, (x0.size(0), 1)) @@ -65,12 +101,16 @@ for k in range(nb_epochs): acc_loss += loss.item() + ema.step() + if k%10 == 0: print(k, loss.item()) +ema.copy() + ###################################################################### # Generate -x = torch.randn(10000, 1) +x = torch.randn(10000, train_input.size(1)) for t in range(T-1, -1, -1): z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) @@ -83,24 +123,44 @@ for t in range(T-1, -1, -1): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) -ax.set_xlim(-1.25, 1.25) -d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'stepfilled', color = 'lightblue', label = 'Train') +if train_input.size(1) == 1: + + ax.set_xlim(-1.25, 1.25) + + d = train_input.flatten().detach().numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') + + d = x.flatten().detach().numpy() + ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') + + ax.legend(frameon = False, loc = 2) + +elif train_input.size(1) == 2: + + ax.set_xlim(-1.25, 1.25) + ax.set_ylim(-1.25, 1.25) + ax.set(aspect = 1) + + d = train_input[:200].detach().numpy() + ax.scatter(d[:, 0], d[:, 1], + color = 'lightblue', label = 'Train') -d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), - density = True, - histtype = 'step', color = 'red', label = 'Synthesis') + d = x[:200].detach().numpy() + ax.scatter(d[:, 0], d[:, 1], + color = 'red', label = 'Synthesis') -ax.legend(frameon = False, loc = 2) + ax.legend(frameon = False, loc = 2) filename = 'diffusion.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') +plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() ###################################################################### -- 2.20.1 From b52a28b72ae3a07f61aaa9fa5b6d063bbe5d4dda Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 13 Aug 2022 02:31:14 +0200 Subject: [PATCH 15/16] Add command line arguments and cuda support. --- minidiffusion.py | 166 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 117 insertions(+), 49 deletions(-) diff --git a/minidiffusion.py b/minidiffusion.py index 0f5948e..8d8dac0 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,60 +5,128 @@ # Written by Francois Fleuret -# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel -# "Denoising Diffusion Probabilistic Models" (2020) -# -# https://arxiv.org/abs/2006.11239 +import math, argparse -import math import matplotlib.pyplot as plt + import torch from torch import nn +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +###################################################################### + +def sample_gaussian_mixture(nb): + p, std = 0.3, 0.2 + result = torch.empty(nb, 1, device = device).normal_(0, std) + result = result + torch.sign(torch.rand(result.size(), device = device) - p) / 2 + return result + +def sample_arc(nb): + theta = torch.rand(nb, device = device) * math.pi + rho = torch.rand(nb, device = device) * 0.1 + 0.7 + result = torch.empty(nb, 2, device = device) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +def sample_spiral(nb): + u = torch.rand(nb, device = device) + rho = u * 0.65 + 0.25 + torch.rand(nb, device = device) * 0.15 + theta = u * math.pi * 3 + result = torch.empty(nb, 2, device = device) + result[:, 0] = theta.cos() * rho + result[:, 1] = theta.sin() * rho + return result + +samplers = { + 'gaussian_mixture': sample_gaussian_mixture, + 'arc': sample_arc, + 'spiral': sample_spiral, +} + +###################################################################### + +parser = argparse.ArgumentParser( + description = '''A minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +"Denoising Diffusion Probabilistic Models" (2020) +https://arxiv.org/abs/2006.11239''', + + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed, < 0 is no seeding') + +parser.add_argument('--nb_epochs', + type = int, default = 100, + help = 'How many epochs') + +parser.add_argument('--batch_size', + type = int, default = 25, + help = 'Batch size') + +parser.add_argument('--nb_samples', + type = int, default = 25000, + help = 'Number of training examples') + +parser.add_argument('--learning_rate', + type = float, default = 1e-3, + help = 'Learning rate') + +parser.add_argument('--ema_decay', + type = float, default = 0.9999, + help = 'EMA decay, < 0 means no EMA') + +data_list = ', '.join( [ str(k) for k in samplers ]) + +parser.add_argument('--data', + type = str, default = 'gaussian_mixture', + help = f'Toy data-set to use: {data_list}') + +args = parser.parse_args() + +if args.seed >= 0: + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + ###################################################################### class EMA: - def __init__(self, model, decay = 0.9999): + def __init__(self, model, decay): self.model = model self.decay = decay + if self.decay < 0: return self.ema = { } with torch.no_grad(): for p in model.parameters(): self.ema[p] = p.clone() def step(self): + if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): self.ema[p].copy_(self.decay * self.ema[p] + (1 - self.decay) * p) def copy(self): + if self.decay < 0: return with torch.no_grad(): for p in self.model.parameters(): p.copy_(self.ema[p]) -###################################################################### - -def sample_gaussian_mixture(nb): - p, std = 0.3, 0.2 - result = torch.empty(nb, 1).normal_(0, std) - result = result + torch.sign(torch.rand(result.size()) - p) / 2 - return result - -def sample_arc(nb): - theta = torch.rand(nb) * math.pi - rho = torch.rand(nb) * 0.1 + 0.7 - result = torch.empty(nb, 2) - result[:, 0] = theta.cos() * rho - result[:, 1] = theta.sin() * rho - return result - ###################################################################### # Train -nb_samples = 25000 - -train_input = sample_gaussian_mixture(nb_samples) -#train_input = sample_arc(nb_samples) +try: + train_input = samplers[args.data](args.nb_samples) +except KeyError: + print(f'unknown data {args.data}') + exit(1) ###################################################################### @@ -69,28 +137,27 @@ model = nn.Sequential( nn.ReLU(), nn.Linear(nh, nh), nn.ReLU(), + nn.Linear(nh, nh), + nn.ReLU(), nn.Linear(nh, train_input.size(1)), -) - -nb_epochs = 50 -batch_size = 25 +).to(device) T = 1000 -beta = torch.linspace(1e-4, 0.02, T) +beta = torch.linspace(1e-4, 0.02, T, device = device) alpha = 1 - beta alpha_bar = alpha.log().cumsum(0).exp() sigma = beta.sqrt() -ema = EMA(model) +ema = EMA(model, decay = args.ema_decay) -for k in range(nb_epochs): +for k in range(args.nb_epochs): acc_loss = 0 - optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) - for x0 in train_input.split(batch_size): - t = torch.randint(T, (x0.size(0), 1)) - eps = torch.randn(x0.size()) + for x0 in train_input.split(args.batch_size): + t = torch.randint(T, (x0.size(0), 1), device = device) + eps = torch.randn(x0.size(), device = device) input = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps input = torch.cat((input, 2 * t / T - 1), 1) output = model(input) @@ -99,22 +166,22 @@ for k in range(nb_epochs): loss.backward() optimizer.step() - acc_loss += loss.item() + acc_loss += loss.item() * x0.size(0) ema.step() - if k%10 == 0: print(k, loss.item()) + if k%10 == 0: print(f'{k} {acc_loss / train_input.size(0)}') ema.copy() ###################################################################### # Generate -x = torch.randn(10000, train_input.size(1)) +x = torch.randn(10000, train_input.size(1), device = device) for t in range(T-1, -1, -1): - z = torch.zeros(x.size()) if t == 0 else torch.randn(x.size()) - input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) + z = torch.zeros(x.size(), device = device) if t == 0 else torch.randn(x.size(), device = device) + input = torch.cat((x, torch.ones(x.size(0), 1, device = device) * 2 * t / T - 1), 1) x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) \ + sigma[t] * z @@ -128,12 +195,12 @@ if train_input.size(1) == 1: ax.set_xlim(-1.25, 1.25) - d = train_input.flatten().detach().numpy() + d = train_input.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), density = True, histtype = 'stepfilled', color = 'lightblue', label = 'Train') - d = x.flatten().detach().numpy() + d = x.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), density = True, histtype = 'step', color = 'red', label = 'Synthesis') @@ -146,21 +213,22 @@ elif train_input.size(1) == 2: ax.set_ylim(-1.25, 1.25) ax.set(aspect = 1) - d = train_input[:200].detach().numpy() + d = train_input[:200].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], color = 'lightblue', label = 'Train') - d = x[:200].detach().numpy() + d = x[:200].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], color = 'red', label = 'Synthesis') ax.legend(frameon = False, loc = 2) -filename = 'diffusion.pdf' +filename = f'diffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') -plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) -plt.show() +if hasattr(plt.get_current_fig_manager(), 'window'): + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + plt.show() ###################################################################### -- 2.20.1 From 142b09825bec53a432795cb34c2cc325b0e994c2 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 13 Aug 2022 02:32:35 +0200 Subject: [PATCH 16/16] Update. --- minidiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minidiffusion.py b/minidiffusion.py index 8d8dac0..075eb82 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -77,7 +77,7 @@ parser.add_argument('--learning_rate', parser.add_argument('--ema_decay', type = float, default = 0.9999, - help = 'EMA decay, < 0 means no EMA') + help = 'EMA decay, < 0 is no EMA') data_list = ', '.join( [ str(k) for k in samplers ]) -- 2.20.1