+++ /dev/null
-#!/usr/bin/env python
-
-import sys, argparse, os, time
-
-import torch, torchvision
-
-from torch import optim, nn
-from torch.nn import functional as F
-
-import torchvision
-
-######################################################################
-
-# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-######################################################################
-
-def _str_to_bool(text):
-    """Convert a command-line word into a boolean.
-
-    ``type=bool`` must not be used with argparse: ``bool('False')`` is
-    ``True`` because every non-empty string is truthy, so the flag could
-    never be switched off from the command line.
-
-    Raises:
-        argparse.ArgumentTypeError: if ``text`` is not a recognised boolean word.
-    """
-    word = text.strip().lower()
-    if word in ('1', 'true', 't', 'yes', 'y', 'on'):
-        return True
-    # The empty string stays false, as it was under the old bool() conversion.
-    if word in ('', '0', 'false', 'f', 'no', 'n', 'off'):
-        return False
-    raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')
-
-
-# Command-line options of the script; every option has a default.
-parser = argparse.ArgumentParser(description='Simple auto-encoder.')
-
-parser.add_argument('--nb_epochs', type=int, default=25)
-
-parser.add_argument('--batch_size', type=int, default=100)
-
-parser.add_argument('--data_dir', type=str, default='./data/')
-
-parser.add_argument('--log_filename', type=str, default='train.log')
-
-parser.add_argument('--embedding_dim', type=int, default=16)
-
-parser.add_argument('--nb_channels', type=int, default=32)
-
-# Accepts a bare `--force_train` (meaning true) as well as an explicit
-# `--force_train true` / `--force_train false`.
-parser.add_argument('--force_train',
-                    type=_str_to_bool, nargs='?', const=True, default=False)
-
-args = parser.parse_args()
-
-# Opened in write mode, so an existing log of the same name is truncated.
-log_file = open(args.log_filename, 'w')
-
-######################################################################
-
-def log_string(s, color=None):
-    """Write the message ``s``, prefixed with the local time, to the log and stdout.
-
-    The line is written to the module-level ``log_file`` (skipped when that is
-    ``None``) and printed to standard output; both are flushed right away.
-    ``color`` is accepted but not used.
-    """
-    stamp = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
-    line = stamp + s
-
-    if log_file is not None:
-        log_file.write(line + '\n')
-        log_file.flush()
-
-    print(line, flush=True)
-
-######################################################################
-
-class AutoEncoder(nn.Module):
-    """Convolutional auto-encoder for single-channel images.
-
-    The encoder stacks five un-padded convolutions and the decoder mirrors
-    them with transposed convolutions. For a 28x28 input the feature map
-    shrinks 28 -> 24 -> 20 -> 9 -> 4 -> 1, which leaves ``embedding_dim``
-    values per image; the decoder grows it back 1 -> 4 -> 9 -> 20 -> 24 -> 28.
-
-    Args:
-        nb_channels: number of channels of every hidden layer.
-        embedding_dim: number of channels produced by the last encoder layer.
-    """
-
-    def __init__(self, nb_channels, embedding_dim):
-        super().__init__()
-
-        hidden = nb_channels
-
-        encoder_layers = [
-            nn.Conv2d(1, hidden, kernel_size=5),                   # 28x28 -> 24x24
-            nn.ReLU(inplace=True),
-            nn.Conv2d(hidden, hidden, kernel_size=5),              # 24x24 -> 20x20
-            nn.ReLU(inplace=True),
-            nn.Conv2d(hidden, hidden, kernel_size=4, stride=2),    # 20x20 -> 9x9
-            nn.ReLU(inplace=True),
-            nn.Conv2d(hidden, hidden, kernel_size=3, stride=2),    # 9x9 -> 4x4
-            nn.ReLU(inplace=True),
-            nn.Conv2d(hidden, embedding_dim, kernel_size=4),       # 4x4 -> 1x1
-        ]
-        self.encoder = nn.Sequential(*encoder_layers)
-
-        decoder_layers = [
-            nn.ConvTranspose2d(embedding_dim, hidden, kernel_size=4),     # 1x1 -> 4x4
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(hidden, hidden, kernel_size=3, stride=2),  # 4x4 -> 9x9
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(hidden, hidden, kernel_size=4, stride=2),  # 9x9 -> 20x20
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(hidden, hidden, kernel_size=5),            # 20x20 -> 24x24
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(hidden, 1, kernel_size=5),                 # 24x24 -> 28x28
-        ]
-        self.decoder = nn.Sequential(*decoder_layers)
-
-    def encode(self, x):
-        """Return the encoder output flattened to one row per input sample."""
-        features = self.encoder(x)
-        return features.view(x.size(0), -1)
-
-    def decode(self, z):
-        """Reshape each row of ``z`` to (channels, 1, 1) and run the decoder on it."""
-        batch_size = z.size(0)
-        return self.decoder(z.view(batch_size, -1, 1, 1))
-
-    def forward(self, x):
-        """Run the encoder then the decoder, without flattening in between."""
-        return self.decoder(self.encoder(x))
-
-######################################################################
-
-# Both MNIST splits are stored under <data_dir>/mnist. The path is joined
-# with os.path.join rather than '+', so a trailing slash in --data_dir does
-# not produce '//' and an empty --data_dir does not resolve to '/mnist/' at
-# the filesystem root.
-mnist_root = os.path.join(args.data_dir, 'mnist')
-
-# Training split, downloaded if needed, viewed as (N, 1, 28, 28) and cast to float.
-train_set = torchvision.datasets.MNIST(mnist_root, train=True, download=True)
-train_input = train_set.data.view(-1, 1, 28, 28).float()
-
-# Test split, prepared the same way.
-test_set = torchvision.datasets.MNIST(mnist_root, train=False, download=True)
-test_input = test_set.data.view(-1, 1, 28, 28).float()
-
-######################################################################
-
-# Keep both splits on the compute device.
-train_input = train_input.to(device)
-test_input = test_input.to(device)
-
-# Standardize in place with the statistics of the training split; the test
-# split is shifted and scaled with those same two values.
-mu = train_input.mean()
-std = train_input.std()
-train_input -= mu
-train_input /= std
-test_input -= mu
-test_input /= std
-
-# Model on the compute device, trained with Adam at a fixed learning rate.
-model = AutoEncoder(args.nb_channels, args.embedding_dim)
-model.to(device)
-optimizer = optim.Adam(model.parameters(), lr=1e-3)
-
-# One pass over the training images per epoch, in consecutive mini-batches.
-for epoch in range(args.nb_epochs):
-    acc_loss = 0
-
-    for batch in train_input.split(args.batch_size):
-        batch = batch.to(device)
-        reconstruction = model.decode(model.encode(batch))
-
-        # Half the summed squared error, divided by the number of samples
-        # in the batch.
-        squared_error = (reconstruction - batch).pow(2).sum()
-        loss = 0.5 * squared_error / batch.size(0)
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        acc_loss += loss.item()
-
-    # Sum of the per-batch losses over the epoch.
-    log_string(f'acc_loss {epoch} {acc_loss}', 'blue')
-
-######################################################################
-
-# Reconstruct the first 256 test images. This is inference only, so autograd
-# is switched off: no graph is recorded and no activations are retained.
-with torch.no_grad():
-    input = test_input[:256]
-    z = model.encode(input)
-    output = model.decode(z)
-
-# Originals and reconstructions are written as 16-column grids; `1 - x`
-# flips the sign of the pixel values before saving.
-torchvision.utils.save_image(1 - input, 'ae-input.png', nrow=16, pad_value=0.8)
-torchvision.utils.save_image(1 - output, 'ae-output.png', nrow=16, pad_value=0.8)
-
-######################################################################
-
-# Decode random codes drawn from a per-dimension Gaussian fitted to the codes
-# of the first 256 training images. This is inference only, so autograd is
-# switched off; it also keeps the in-place normal_() from touching a tensor
-# that belongs to a recorded graph.
-with torch.no_grad():
-    input = train_input[:256]
-    z = model.encode(input)
-    # NOTE: this rebinds `mu` and `std`, which held the pixel statistics used
-    # to normalize the data; from here on they are the latent statistics.
-    mu, std = z.mean(0), z.std(0)
-    # Overwrite the codes with standard-normal noise, then rescale and shift.
-    z = z.normal_() * std + mu
-    output = model.decode(z)
-
-torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow=16, pad_value=0.8)
-
-######################################################################