Update.
[pytorch.git] / causal-autoregression.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # ./causal-autoregression.py --data=toy1d
9 # ./causal-autoregression.py --data=toy1d --dilation
10 # ./causal-autoregression.py --data=mnist
11 # ./causal-autoregression.py --data=mnist --positional
12
13 import argparse, math, sys, time
14 import torch, torchvision
15
16 from torch import nn
17 from torch.nn import functional as F
18
19 ######################################################################
20
21 def save_images(x, filename, nrow = 12):
22     print(f'Writing {filename}')
23     torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
24                                  filename,
25                                  nrow = nrow, pad_value=1.0)
26
27 ######################################################################
28
29 parser = argparse.ArgumentParser(
30     description = 'An implementation of a causal autoregression model',
31     formatter_class = argparse.ArgumentDefaultsHelpFormatter
32 )
33
34 parser.add_argument('--data',
35                     type = str, default = 'toy1d',
36                     help = 'What data')
37
38 parser.add_argument('--seed',
39                     type = int, default = 0,
40                     help = 'Random seed (default 0, < 0 is no seeding)')
41
42 parser.add_argument('--nb_epochs',
43                     type = int, default = -1,
44                     help = 'How many epochs')
45
46 parser.add_argument('--batch_size',
47                     type = int, default = 100,
48                     help = 'Batch size')
49
50 parser.add_argument('--learning_rate',
51                     type = float, default = 1e-3,
52                     help = 'Batch size')
53
54 parser.add_argument('--positional',
55                     action='store_true', default = False,
56                     help = 'Do we provide a positional encoding as input')
57
58 parser.add_argument('--dilation',
59                     action='store_true', default = False,
60                     help = 'Do we provide a positional encoding as input')
61
62 ######################################################################
63
64 args = parser.parse_args()
65
66 if args.seed >= 0:
67     torch.manual_seed(args.seed)
68
69 if args.nb_epochs < 0:
70     if args.data == 'toy1d':
71         args.nb_epochs = 100
72     elif args.data == 'mnist':
73         args.nb_epochs = 25
74
75 ######################################################################
76
77 if torch.cuda.is_available():
78     print('Cuda is available')
79     device = torch.device('cuda')
80     torch.backends.cudnn.benchmark = True
81 else:
82     device = torch.device('cpu')
83
84 ######################################################################
85
86 class NetToy1d(nn.Module):
87     def __init__(self, nb_classes, ks = 2, nc = 32):
88         super().__init__()
89         self.pad = (ks - 1, 0)
90         self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
91         self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
92         self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
93         self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
94         self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
95         self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
96
97     def forward(self, x):
98         x = F.relu(self.conv0(F.pad(x, (1, -1))))
99         x = F.relu(self.conv1(F.pad(x, self.pad)))
100         x = F.relu(self.conv2(F.pad(x, self.pad)))
101         x = F.relu(self.conv3(F.pad(x, self.pad)))
102         x = F.relu(self.conv4(F.pad(x, self.pad)))
103         x = self.conv5(x)
104         return x.permute(0, 2, 1).contiguous()
105
106 class NetToy1dWithDilation(nn.Module):
107     def __init__(self, nb_classes, ks = 2, nc = 32):
108         super().__init__()
109         self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
110         self.pad1 = ((ks-1) * 2, 0)
111         self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
112         self.pad2 = ((ks-1) * 4, 0)
113         self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
114         self.pad3 = ((ks-1) * 8, 0)
115         self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
116         self.pad4 = ((ks-1) * 16, 0)
117         self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
118         self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
119
120     def forward(self, x):
121         x = F.relu(self.conv0(F.pad(x, (1, -1))))
122         x = F.relu(self.conv1(F.pad(x, self.pad2)))
123         x = F.relu(self.conv2(F.pad(x, self.pad3)))
124         x = F.relu(self.conv3(F.pad(x, self.pad4)))
125         x = F.relu(self.conv4(F.pad(x, self.pad5)))
126         x = self.conv5(x)
127         return x.permute(0, 2, 1).contiguous()
128
129 ######################################################################
130
131 class PixelCNN(nn.Module):
132     def __init__(self, nb_classes, in_channels = 1, ks = 5):
133         super().__init__()
134
135         self.hpad = (ks//2, ks//2, ks//2, 0)
136         self.vpad = (ks//2,     0,     0, 0)
137
138         self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
139         self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
140         self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
141         self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
142         self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
143         self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
144
145     def forward(self, x):
146         xh = F.pad(x, (0, 0, 1, -1))
147         xv = F.pad(x, (1, -1, 0, 0))
148         xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
149         xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
150         xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
151         xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
152         x = F.relu(self.final1(torch.cat((xh, xv), 1)))
153         x = self.final2(x)
154
155         return x.permute(0, 2, 3, 1).contiguous()
156
157 ######################################################################
158
159 def positional_tensor(height, width):
160     index_h = torch.arange(height).view(1, -1)
161     m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
162     b_h = (index_h // m_h) % 2
163     i_h = b_h[None, :, None, :].expand(-1, -1, height, -1)
164
165     index_w = torch.arange(width).view(1, -1)
166     m_w = (2 ** torch.arange(math.ceil(math.log2(width)))).view(-1, 1)
167     b_w = (index_w // m_w) % 2
168     i_w = b_w[None, :, :, None].expand(-1, -1, -1, width)
169
170     return torch.cat((i_w, i_h), 1)
171
172 ######################################################################
173
174 str_experiment = args.data
175
176 if args.positional:
177     str_experiment += '-positional'
178
179 if args.dilation:
180     str_experiment += '-dilation'
181
182 log_file = open('causalar-' + str_experiment + '-train.log', 'w')
183
184 def log_string(s):
185     s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
186     print(s)
187     log_file.write(s + '\n')
188     log_file.flush()
189
190 ######################################################################
191
192 def generate_sequences(nb, len):
193     nb_parts = 2
194
195     r = torch.empty(nb, len)
196
197     x = torch.empty(nb, nb_parts).uniform_(-1, 1)
198     x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
199     x = x * torch.linspace(0, len-1, len).view(1, -1) + len
200
201     for n in range(nb):
202         a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
203         a[0] = 0
204         a[a.size(0) - 1] = len
205         for k in range(a.size(0) - 1):
206             r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
207
208     return r.round().long()
209
210 ######################################################################
211
212 if args.data == 'toy1d':
213     len = 32
214     train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
215     if args.dilation:
216         model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
217     else:
218         model = NetToy1d(nb_classes = 2 * len).to(device)
219
220 elif args.data == 'mnist':
221     train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
222     train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
223
224     model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
225     in_channels = train_input.size(1)
226
227     if args.positional:
228         height, width = train_input.size(2), train_input.size(3)
229         positional_input = positional_tensor(height, width).float().to(device)
230         in_channels += positional_input.size(1)
231
232     model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
233
234 else:
235     raise ValueError('Unknown data ' + args.data)
236
237 ######################################################################
238
239 mean, std = train_input.float().mean(), train_input.float().std()
240
241 nb_parameters = sum(t.numel() for t in model.parameters())
242 log_string(f'nb_parameters {nb_parameters}')
243
244 cross_entropy = nn.CrossEntropyLoss().to(device)
245 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
246
247 for e in range(args.nb_epochs):
248
249     nb_batches, acc_loss = 0, 0.0
250
251     for sequences in train_input.split(args.batch_size):
252         input = (sequences - mean)/std
253
254         if args.positional:
255             input = torch.cat(
256                 (input, positional_input.expand(input.size(0), -1, -1, -1)),
257                 1
258             )
259
260         output = model(input)
261
262         loss = cross_entropy(
263             output.view(-1, output.size(-1)),
264             sequences.view(-1)
265         )
266
267         optimizer.zero_grad()
268         loss.backward()
269         optimizer.step()
270
271         nb_batches += 1
272         acc_loss += loss.item()
273
274     log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
275
276     sys.stdout.flush()
277
278 ######################################################################
279
280 generated = train_input.new_zeros((48,) + train_input.size()[1:])
281
282 flat = generated.view(generated.size(0), -1)
283
284 for t in range(flat.size(1)):
285     input = (generated.float() - mean) / std
286     if args.positional:
287         input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
288     output = model(input)
289     logits = output.view(flat.size() + (-1,))[:, t]
290     dist = torch.distributions.categorical.Categorical(logits = logits)
291     flat[:, t] = dist.sample()
292
293 ######################################################################
294
295 if args.data == 'toy1d':
296
297     with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
298         for j in range(train_input.size(2)):
299             file.write(f'{j}')
300             for i in range(min(train_input.size(0), 25)):
301                 file.write(f' {train_input[i, 0, j]}')
302             file.write('\n')
303
304     with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
305         for j in range(generated.size(2)):
306             file.write(f'{j}')
307             for i in range(generated.size(0)):
308                 file.write(f' {generated[i, 0, j]}')
309             file.write('\n')
310
311 elif args.data == 'mnist':
312
313     img_train = 1 - train_input[:generated.size(0)].float() / 255
314     img_generated = 1 - generated.float() / 255
315
316     save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
317     save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
318
319 ######################################################################