#!/usr/bin/env python

# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: /home/fleuret/conda/bin/python
# @XREMOTE_PRE: killall -q -9 python || true
# @XREMOTE_PRE: ln -sf /home/fleuret/data/pytorch ./data
# @XREMOTE_GET: *.log *.dat *.png *.pth

"""Simple convolutional auto-encoder trained on MNIST.

The script trains the model with a squared-error reconstruction loss,
then writes three image grids in the current directory:

* ae-input.png  -- the first 256 test images,
* ae-output.png -- their reconstructions,
* ae-synth.png  -- images decoded from latent codes drawn from a
  per-dimension Gaussian fitted on the codes of 256 training images.
"""

import sys, argparse, os, time
import atexit

import torch, torchvision

from torch import optim, nn
from torch.nn import functional as F

######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

def str_to_bool(s):
    """Convert a command-line string to a boolean.

    argparse's `type = bool` calls `bool()` on the raw string, so any
    non-empty value, including 'False', becomes True. This converter
    interprets the text instead. The empty string maps to False, as
    it did with `bool('')`.

    Raises argparse.ArgumentTypeError on an unrecognized value.
    """
    value = s.strip().lower()

    if value in ('true', 't', 'yes', 'y', '1'):
        return True

    if value in ('false', 'f', 'no', 'n', '0', ''):
        return False

    raise argparse.ArgumentTypeError(f'expected a boolean value, got {s!r}')

######################################################################

parser = argparse.ArgumentParser(description = 'Simple auto-encoder.')

parser.add_argument('--nb_epochs',
                    type = int, default = 25)

parser.add_argument('--batch_size',
                    type = int, default = 100)

parser.add_argument('--data_dir',
                    type = str, default = './data/')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--embedding_dim',
                    type = int, default = 16)

parser.add_argument('--nb_channels',
                    type = int, default = 32)

# NOTE: this flag is parsed but not read anywhere in this script.
# Both `--force_train` alone and `--force_train <value>` are accepted.
parser.add_argument('--force_train',
                    type = str_to_bool, nargs = '?',
                    const = True, default = False)

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

# Make sure the log is closed when the interpreter exits, including
# when the script stops on an exception.
atexit.register(log_file.close)

######################################################################

def log_string(s, color = None):
    """Write the time-stamped string `s` to the log file and stdout.

    `color` is accepted for compatibility with the callers but is not
    used.
    """
    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    print(t + s, flush = True)

######################################################################

class AutoEncoder(nn.Module):
    """Convolutional auto-encoder for 1x28x28 images.

    The encoder maps an Nx1x28x28 tensor to an N x embedding_dim x 1 x 1
    tensor, and the decoder maps it back to Nx1x28x28.
    """

    def __init__(self, nb_channels, embedding_dim):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4) # to 1x1
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4), # from 1x1
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
        )

    def encode(self, x):
        """Return the latent codes of `x`, flattened to one row per sample."""
        return self.encoder(x).view(x.size(0), -1)

    def decode(self, z):
        """Return the images decoded from the flat latent codes `z`."""
        return self.decoder(z.view(z.size(0), -1, 1, 1))

    def forward(self, x):
        """Encode then decode `x`."""
        return self.decoder(self.encoder(x))

######################################################################

mnist_dir = os.path.join(args.data_dir, 'mnist')

train_set = torchvision.datasets.MNIST(mnist_dir,
                                       train = True, download = True)
train_input = train_set.data.view(-1, 1, 28, 28).float()

test_set = torchvision.datasets.MNIST(mnist_dir,
                                      train = False, download = True)
test_input = test_set.data.view(-1, 1, 28, 28).float()

######################################################################

train_input, test_input = train_input.to(device), test_input.to(device)

# Both sets are normalized with the statistics of the training set
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)

model = AutoEncoder(args.nb_channels, args.embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

model.train()

for epoch in range(args.nb_epochs):
    acc_loss = 0

    # The data is already on `device`, so the batches are used as-is
    for batch in train_input.split(args.batch_size):
        output = model.decode(model.encode(batch))

        # Half the summed squared error, averaged over the samples
        loss = 0.5 * (output - batch).pow(2).sum() / batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()

    log_string(f'acc_loss {epoch} {acc_loss}', 'blue')

######################################################################

# No gradient is needed from here on

model.eval()

with torch.no_grad():
    samples = test_input[:256]
    output = model.decode(model.encode(samples))

    torchvision.utils.save_image(1 - samples, 'ae-input.png', nrow = 16, pad_value = 0.8)
    torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)

######################################################################

with torch.no_grad():
    z = model.encode(train_input[:256])

    # Per-dimension Gaussian fitted on the latent codes
    z_mu, z_std = z.mean(0), z.std(0)

    # The codes are overwritten in place with standard normal samples,
    # then rescaled to the fitted statistics
    z = z.normal_() * z_std + z_mu

    output = model.decode(z)

    torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)

######################################################################