#!/usr/bin/env python
"""Simple convolutional auto-encoder trained on MNIST.

Reconstructed from a whitespace-collapsed patch (commit
d09d91f2b5b594f91a757134c5ce014ae8d68a9a by Francois Fleuret, dated
7 Dec 2021) that deleted ``autoencoder.py``.

Running the script trains the model on the MNIST training images, logs the
accumulated loss of every epoch, and writes three image grids to the
current directory: ``ae-input.png``, ``ae-output.png`` and ``ae-synth.png``.
"""

import argparse
import os
import sys
import time

import torch
from torch import nn, optim
from torch.nn import functional as F  # noqa: F401  (kept from the original import list)

# torchvision is imported inside the functions that need it, so that
# AutoEncoder can be imported without torchvision being installed.

######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# File object that log_string() mirrors its output to. It stays None until
# main() opens the log file, and is reset to None when the run ends.
log_file = None

######################################################################


def _str_to_bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized
            spelling of true or false.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Simple auto-encoder.')

    parser.add_argument('--nb_epochs', type=int, default=25)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--data_dir', type=str, default='./data/')
    parser.add_argument('--log_filename', type=str, default='train.log')
    parser.add_argument('--embedding_dim', type=int, default=16)
    parser.add_argument('--nb_channels', type=int, default=32)

    # Accepts "--force_train True/False" as before (now parsed correctly),
    # and also a bare "--force_train", which means True.
    parser.add_argument('--force_train', type=_str_to_bool,
                        nargs='?', const=True, default=False)

    return parser.parse_args(argv)

######################################################################


def log_string(s, color=None):
    """Print ``s`` prefixed with a timestamp, and mirror it to the log file.

    Args:
        s: the message to log.
        color: accepted for compatibility with existing calls; not used.
    """
    t = time.strftime('%Y-%m-%d_%H:%M:%S - ', time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    print(t + s)
    sys.stdout.flush()

######################################################################


class AutoEncoder(nn.Module):
    """Convolutional auto-encoder for 1x28x28 images.

    The encoder reduces a 28x28 image to ``embedding_dim`` channels at
    resolution 1x1; the decoder mirrors it with transposed convolutions.

    Args:
        nb_channels: number of channels in the hidden layers.
        embedding_dim: dimension of the latent code.
    """

    def __init__(self, nb_channels, embedding_dim):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),  # to 1x1
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),  # from 1x1
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # from 4x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # from 9x9
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nb_channels, 1, kernel_size=5),  # from 24x24
        )

    def encode(self, x):
        """Return the latent codes of ``x``, flattened to (batch, -1)."""
        return self.encoder(x).view(x.size(0), -1)

    def decode(self, z):
        """Decode flat latent codes ``z`` of shape (batch, embedding_dim)."""
        return self.decoder(z.view(z.size(0), -1, 1, 1))

    def forward(self, x):
        """Encode then decode ``x``, returning the reconstruction."""
        return self.decoder(self.encoder(x))

######################################################################


def _load_mnist_images(data_dir, train):
    """Download MNIST if needed and return its images as an Nx1x28x28 float tensor."""
    import torchvision

    dataset = torchvision.datasets.MNIST(os.path.join(data_dir, 'mnist'),
                                         train=train, download=True)
    return dataset.data.view(-1, 1, 28, 28).float()


def _train(model, optimizer, train_input, nb_epochs, batch_size):
    """Train ``model`` to reconstruct ``train_input``, logging each epoch's loss."""
    for epoch in range(nb_epochs):
        acc_loss = 0
        for batch in train_input.split(batch_size):
            output = model.decode(model.encode(batch))
            # Half the summed squared error, averaged over the samples
            loss = 0.5 * (output - batch).pow(2).sum() / batch.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc_loss += loss.item()

        log_string(f'acc_loss {epoch} {acc_loss}', 'blue')


def _save_images(model, train_input, test_input):
    """Write the input, reconstruction and synthetic-sample image grids."""
    from torchvision.utils import save_image

    # Nothing is back-propagated here, so no autograd graph is needed
    with torch.no_grad():
        samples = test_input[:256]
        output = model.decode(model.encode(samples))
        save_image(1 - samples, 'ae-input.png', nrow=16, pad_value=0.8)
        save_image(1 - output, 'ae-output.png', nrow=16, pad_value=0.8)

        # Synthesize images by decoding codes drawn from a Gaussian that
        # matches the per-dimension mean and standard deviation of the
        # codes of 256 training images.
        z = model.encode(train_input[:256])
        z_mu, z_std = z.mean(0), z.std(0)
        z = torch.randn_like(z) * z_std + z_mu
        save_image(1 - model.decode(z), 'ae-synth.png', nrow=16, pad_value=0.8)


def _run(args):
    """Load the data, train the auto-encoder and write the result images."""
    train_input = _load_mnist_images(args.data_dir, train=True).to(device)
    test_input = _load_mnist_images(args.data_dir, train=False).to(device)

    # The test set is normalized with the statistics of the training set
    mu, std = train_input.mean(), train_input.std()
    train_input.sub_(mu).div_(std)
    test_input.sub_(mu).div_(std)

    model = AutoEncoder(args.nb_channels, args.embedding_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    _train(model, optimizer, train_input, args.nb_epochs, args.batch_size)
    _save_images(model, train_input, test_input)


def main(argv=None):
    """Script entry point: parse the options and run with the log file open."""
    global log_file

    args = parse_args(argv)

    with open(args.log_filename, 'w') as handle:
        log_file = handle
        try:
            _run(args)
        finally:
            # Never leave log_string() pointing at a closed file
            log_file = None

######################################################################


if __name__ == '__main__':
    main()