Update.
[pytorch.git] / causal-autoregression.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # ./causal-autoregression.py --data=toy1d
9 # ./causal-autoregression.py --data=toy1d --dilation
10 # ./causal-autoregression.py --data=mnist
11 # ./causal-autoregression.py --data=mnist --positional
12
13 import argparse, math, sys, time
14 import torch, torchvision
15
16 from torch import nn
17 from torch.nn import functional as F
18
19 ######################################################################
20
21
22 def save_images(x, filename, nrow=12):
23     print(f"Writing {filename}")
24     torchvision.utils.save_image(
25         x.narrow(0, 0, min(48, x.size(0))), filename, nrow=nrow, pad_value=1.0
26     )
27
28
29 ######################################################################
30
31 parser = argparse.ArgumentParser(
32     description="An implementation of a causal autoregression model",
33     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
34 )
35
36 parser.add_argument("--data", type=str, default="toy1d", help="What data")
37
38 parser.add_argument(
39     "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
40 )
41
42 parser.add_argument("--nb_epochs", type=int, default=-1, help="How many epochs")
43
44 parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
45
46 parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
47
48 parser.add_argument(
49     "--positional",
50     action="store_true",
51     default=False,
52     help="Do we provide a positional encoding as input",
53 )
54
55 parser.add_argument(
56     "--dilation",
57     action="store_true",
58     default=False,
59     help="Do we provide a positional encoding as input",
60 )
61
62 ######################################################################
63
64 args = parser.parse_args()
65
66 if args.seed >= 0:
67     torch.manual_seed(args.seed)
68
69 if args.nb_epochs < 0:
70     if args.data == "toy1d":
71         args.nb_epochs = 100
72     elif args.data == "mnist":
73         args.nb_epochs = 25
74
75 ######################################################################
76
77 if torch.cuda.is_available():
78     print("Cuda is available")
79     device = torch.device("cuda")
80     torch.backends.cudnn.benchmark = True
81 else:
82     device = torch.device("cpu")
83
84 ######################################################################
85
86
87 class NetToy1d(nn.Module):
88     def __init__(self, nb_classes, ks=2, nc=32):
89         super().__init__()
90         self.pad = (ks - 1, 0)
91         self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
92         self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks)
93         self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks)
94         self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks)
95         self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks)
96         self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
97
98     def forward(self, x):
99         x = F.relu(self.conv0(F.pad(x, (1, -1))))
100         x = F.relu(self.conv1(F.pad(x, self.pad)))
101         x = F.relu(self.conv2(F.pad(x, self.pad)))
102         x = F.relu(self.conv3(F.pad(x, self.pad)))
103         x = F.relu(self.conv4(F.pad(x, self.pad)))
104         x = self.conv5(x)
105         return x.permute(0, 2, 1).contiguous()
106
107
108 class NetToy1dWithDilation(nn.Module):
109     def __init__(self, nb_classes, ks=2, nc=32):
110         super().__init__()
111         self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
112         self.pad1 = ((ks - 1) * 2, 0)
113         self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=2)
114         self.pad2 = ((ks - 1) * 4, 0)
115         self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=4)
116         self.pad3 = ((ks - 1) * 8, 0)
117         self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=8)
118         self.pad4 = ((ks - 1) * 16, 0)
119         self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=16)
120         self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
121
122     def forward(self, x):
123         x = F.relu(self.conv0(F.pad(x, (1, -1))))
124         x = F.relu(self.conv1(F.pad(x, self.pad2)))
125         x = F.relu(self.conv2(F.pad(x, self.pad3)))
126         x = F.relu(self.conv3(F.pad(x, self.pad4)))
127         x = F.relu(self.conv4(F.pad(x, self.pad5)))
128         x = self.conv5(x)
129         return x.permute(0, 2, 1).contiguous()
130
131
132 ######################################################################
133
134
135 class PixelCNN(nn.Module):
136     def __init__(self, nb_classes, in_channels=1, ks=5):
137         super().__init__()
138
139         self.hpad = (ks // 2, ks // 2, ks // 2, 0)
140         self.vpad = (ks // 2, 0, 0, 0)
141
142         self.conv1h = nn.Conv2d(in_channels, 32, kernel_size=(ks // 2 + 1, ks))
143         self.conv2h = nn.Conv2d(32, 64, kernel_size=(ks // 2 + 1, ks))
144         self.conv1v = nn.Conv2d(in_channels, 32, kernel_size=(1, ks // 2 + 1))
145         self.conv2v = nn.Conv2d(32, 64, kernel_size=(1, ks // 2 + 1))
146         self.final1 = nn.Conv2d(128, 128, kernel_size=1)
147         self.final2 = nn.Conv2d(128, nb_classes, kernel_size=1)
148
149     def forward(self, x):
150         xh = F.pad(x, (0, 0, 1, -1))
151         xv = F.pad(x, (1, -1, 0, 0))
152         xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
153         xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
154         xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
155         xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
156         x = F.relu(self.final1(torch.cat((xh, xv), 1)))
157         x = self.final2(x)
158
159         return x.permute(0, 2, 3, 1).contiguous()
160
161
162 ######################################################################
163
164
165 def positional_tensor(height, width):
166     index_h = torch.arange(height).view(1, -1)
167     m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
168     b_h = (index_h // m_h) % 2
169     i_h = b_h[None, :, None, :].expand(-1, -1, height, -1)
170
171     index_w = torch.arange(width).view(1, -1)
172     m_w = (2 ** torch.arange(math.ceil(math.log2(width)))).view(-1, 1)
173     b_w = (index_w // m_w) % 2
174     i_w = b_w[None, :, :, None].expand(-1, -1, -1, width)
175
176     return torch.cat((i_w, i_h), 1)
177
178
179 ######################################################################
180
181 str_experiment = args.data
182
183 if args.positional:
184     str_experiment += "-positional"
185
186 if args.dilation:
187     str_experiment += "-dilation"
188
189 log_file = open("causalar-" + str_experiment + "-train.log", "w")
190
191
192 def log_string(s):
193     s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + " " + s
194     print(s)
195     log_file.write(s + "\n")
196     log_file.flush()
197
198
199 ######################################################################
200
201
202 def generate_sequences(nb, len):
203     nb_parts = 2
204
205     r = torch.empty(nb, len)
206
207     x = torch.empty(nb, nb_parts).uniform_(-1, 1)
208     x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
209     x = x * torch.linspace(0, len - 1, len).view(1, -1) + len
210
211     for n in range(nb):
212         a = torch.randperm(len - 2)[: nb_parts + 1].sort()[0]
213         a[0] = 0
214         a[a.size(0) - 1] = len
215         for k in range(a.size(0) - 1):
216             r[n, a[k] : a[k + 1]] = x[n, k, : a[k + 1] - a[k]]
217
218     return r.round().long()
219
220
221 ######################################################################
222
223 if args.data == "toy1d":
224     len = 32
225     train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
226     if args.dilation:
227         model = NetToy1dWithDilation(nb_classes=2 * len).to(device)
228     else:
229         model = NetToy1d(nb_classes=2 * len).to(device)
230
231 elif args.data == "mnist":
232     train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
233     train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
234
235     model = PixelCNN(nb_classes=256, in_channels=1).to(device)
236     in_channels = train_input.size(1)
237
238     if args.positional:
239         height, width = train_input.size(2), train_input.size(3)
240         positional_input = positional_tensor(height, width).float().to(device)
241         in_channels += positional_input.size(1)
242
243     model = PixelCNN(nb_classes=256, in_channels=in_channels).to(device)
244
245 else:
246     raise ValueError("Unknown data " + args.data)
247
248 ######################################################################
249
250 mean, std = train_input.float().mean(), train_input.float().std()
251
252 nb_parameters = sum(t.numel() for t in model.parameters())
253 log_string(f"nb_parameters {nb_parameters}")
254
255 cross_entropy = nn.CrossEntropyLoss().to(device)
256 optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
257
258 for e in range(args.nb_epochs):
259     nb_batches, acc_loss = 0, 0.0
260
261     for sequences in train_input.split(args.batch_size):
262         input = (sequences - mean) / std
263
264         if args.positional:
265             input = torch.cat(
266                 (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
267             )
268
269         output = model(input)
270
271         loss = cross_entropy(output.view(-1, output.size(-1)), sequences.view(-1))
272
273         optimizer.zero_grad()
274         loss.backward()
275         optimizer.step()
276
277         nb_batches += 1
278         acc_loss += loss.item()
279
280     log_string(f"{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}")
281
282     sys.stdout.flush()
283
284 ######################################################################
285
286 generated = train_input.new_zeros((48,) + train_input.size()[1:])
287
288 flat = generated.view(generated.size(0), -1)
289
290 for t in range(flat.size(1)):
291     input = (generated.float() - mean) / std
292     if args.positional:
293         input = torch.cat(
294             (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
295         )
296     output = model(input)
297     logits = output.view(flat.size() + (-1,))[:, t]
298     dist = torch.distributions.categorical.Categorical(logits=logits)
299     flat[:, t] = dist.sample()
300
301 ######################################################################
302
303 if args.data == "toy1d":
304     with open("causalar-" + str_experiment + "-train.dat", "w") as file:
305         for j in range(train_input.size(2)):
306             file.write(f"{j}")
307             for i in range(min(train_input.size(0), 25)):
308                 file.write(f" {train_input[i, 0, j]}")
309             file.write("\n")
310
311     with open("causalar-" + str_experiment + "-generated.dat", "w") as file:
312         for j in range(generated.size(2)):
313             file.write(f"{j}")
314             for i in range(generated.size(0)):
315                 file.write(f" {generated[i, 0, j]}")
316             file.write("\n")
317
318 elif args.data == "mnist":
319     img_train = 1 - train_input[: generated.size(0)].float() / 255
320     img_generated = 1 - generated.float() / 255
321
322     save_images(img_train, "causalar-" + str_experiment + "-train.png", nrow=12)
323     save_images(img_generated, "causalar-" + str_experiment + "-generated.png", nrow=12)
324
325 ######################################################################