3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import sys, argparse, time
10 import torch, torchvision
12 from torch import optim, nn
13 from torch.nn import functional as F
15 ######################################################################
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

parser = argparse.ArgumentParser(description = 'Tiny LeNet-like auto-encoder.')

# Command-line options as (flag, type, default) rows, registered in order.
for _flag, _type, _default in (
    ('--nb_epochs', int, 25),
    ('--batch_size', int, 100),
    ('--data_dir', str, './data/'),
    ('--log_filename', str, 'train.log'),
    ('--embedding_dim', int, 8),
    ('--nb_channels', int, 32),
):
    parser.add_argument(_flag, type = _type, default = _default)

args = parser.parse_args()

# The log file is truncated at start-up and stays open for the whole run.
log_file = open(args.log_filename, 'w')
45 ######################################################################
def log_string(s):
    """Append the message `s` to the log file, prefixed with a local timestamp.

    Nothing is written when the module-level `log_file` is None.
    """
    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        # Flush so each entry is visible in the file while training is
        # still running, instead of sitting in the write buffer.
        log_file.flush()
57 ######################################################################
class AutoEncoder(nn.Module):
    """Small convolutional auto-encoder.

    The encoder is a stack of five convolutions that reduces a
    1x28x28 image to an `embedding_dim`x1x1 code; the decoder mirrors
    it with transposed convolutions. The spatial sizes quoted in the
    layer comments hold for 28x28 inputs.

    Args:
        nb_channels: number of feature maps in every hidden layer.
        embedding_dim: number of channels of the 1x1 code.
    """

    def __init__(self, nb_channels, embedding_dim):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, nb_channels, kernel_size = 5), # to 24x24
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 5), # to 20x20
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # to 9x9
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # to 4x4
            nn.ReLU(inplace = True),
            nn.Conv2d(nb_channels, embedding_dim, kernel_size = 4)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size = 4),
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 3, stride = 2), # from 4x4
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 4, stride = 2), # from 9x9
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size = 5), # from 20x20
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(nb_channels, 1, kernel_size = 5), # from 24x24
        )

    def encode(self, x):
        """Return the code of `x`, flattened to one row per sample."""
        return self.encoder(x).view(x.size(0), -1)

    def decode(self, z):
        """Return the images decoded from the flat codes `z`.

        Each row of `z` is reshaped to a 1x1 spatial map before being
        fed to the decoder.
        """
        return self.decoder(z.view(z.size(0), -1, 1, 1))

    def forward(self, x):
        """Return the reconstruction of `x`: encode, then decode."""
        return self.decode(self.encode(x))
98 ######################################################################
# Both MNIST splits are fetched (downloaded if absent) into the same folder.
train_set = torchvision.datasets.MNIST(
    args.data_dir + '/mnist/', train = True, download = True
)
test_set = torchvision.datasets.MNIST(
    args.data_dir + '/mnist/', train = False, download = True
)

# Reshape the raw pixel data to float tensors of shape N x 1 x 28 x 28.
train_input = train_set.data.view(-1, 1, 28, 28).float()
test_input = test_set.data.view(-1, 1, 28, 28).float()
108 ######################################################################
model = AutoEncoder(args.nb_channels, args.embedding_dim)

# The parameters must live on the same device as the data fed to the
# model. The move is done before the optimizer is built so that the
# optimizer is handed the parameters in their final location.
model.to(device)

optimizer = optim.Adam(model.parameters(), lr = 1e-3)

train_input, test_input = train_input.to(device), test_input.to(device)

# Standardise both splits in place with the training-set mean and
# standard deviation.
mu, std = train_input.mean(), train_input.std()
train_input.sub_(mu).div_(std)
test_input.sub_(mu).div_(std)
121 ######################################################################
for epoch in range(args.nb_epochs):

    # Sum of the per-batch losses over this epoch.
    acc_loss = 0

    for batch in train_input.split(args.batch_size):
        output = model(batch)
        # Half the squared reconstruction error, summed over pixels and
        # averaged over the samples of the batch.
        loss = 0.5 * (output - batch).pow(2).sum() / batch.size(0)

        # Clear stale gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()

    log_string('acc_loss {:d} {:f}.'.format(epoch, acc_loss))
139 ######################################################################
# The first 256 test images are used for the image dumps below.
input = test_input[:256]

# Nothing is optimised past this point, so autograd bookkeeping is
# disabled; this also makes the in-place normal_() below act on a plain
# tensor rather than on a node of a graph.
with torch.no_grad():

    # Reconstruction: encode the test images and decode them back.
    z = model.encode(input)
    output = model.decode(z)

    # NOTE(review): the tensors are in standardised units, not [0, 1];
    # confirm that saving `1 - x` directly gives the intended rendering.
    torchvision.utils.save_image(1 - input, 'ae-input.png', nrow = 16, pad_value = 0.8)
    torchvision.utils.save_image(1 - output, 'ae-output.png', nrow = 16, pad_value = 0.8)

    # Synthesis: replace each code by a Gaussian sample whose
    # per-dimension mean and standard deviation match those of the
    # codes of the test images, then decode the samples.
    z = model.encode(input)
    mu, std = z.mean(0), z.std(0)
    z = z.normal_() * std + mu
    output = model.decode(z)

    torchvision.utils.save_image(1 - output, 'ae-synth.png', nrow = 16, pad_value = 0.8)
160 ######################################################################