3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # ./causal-autoregression.py --data=toy1d
9 # ./causal-autoregression.py --data=toy1d --dilation
10 # ./causal-autoregression.py --data=mnist
11 # ./causal-autoregression.py --data=mnist --positional
import argparse, math, sys, time
import torch, torchvision

from torch import nn
from torch.nn import functional as F
19 ######################################################################
def save_images(x, filename, nrow = 12, max_images = 48):
    """Save the first images of a batch as a single grid image file.

    Args:
        x: batch of images (images along dim 0), handed to
           torchvision.utils.save_image after truncation.
        filename: path of the image file to write.
        nrow: number of images per row of the grid.
        max_images: at most this many images, taken from the front of the
            batch, are written (48 by default, i.e. 4 rows of 12).
    """
    print(f'Writing {filename}')
    # x[:max_images] is equivalent to x.narrow(0, 0, min(max_images, x.size(0))).
    # The output path must be passed explicitly; save_image has no default for it.
    torchvision.utils.save_image(x[:max_images], filename,
                                 nrow = nrow, pad_value = 1.0)
27 ######################################################################
# Command-line interface. ArgumentDefaultsHelpFormatter appends each
# option's default value to its --help line.
parser = argparse.ArgumentParser(
    description = 'An implementation of a causal autoregression model',
    formatter_class = argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--data',
                    type = str, default = 'toy1d',
                    help = 'Which data to use (toy1d or mnist)')

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

# NOTE(review): a negative value is presumably replaced later by a
# per-dataset default (see the `args.nb_epochs < 0` check).
parser.add_argument('--nb_epochs',
                    type = int, default = -1,
                    help = 'How many epochs')

parser.add_argument('--batch_size',
                    type = int, default = 100,
                    help = 'Batch size')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3,
                    help = 'Learning rate')

parser.add_argument('--positional',
                    action='store_true', default = False,
                    help = 'Do we provide a positional encoding as input')

# The help previously duplicated --positional's text; this flag selects the
# dilated toy1d network instead.
parser.add_argument('--dilation',
                    action='store_true', default = False,
                    help = 'Use dilated causal convolutions (toy1d only)')
62 ######################################################################
args = parser.parse_args()

# The --seed help promises that a negative value disables seeding, so only
# seed the RNG for non-negative values (torch.manual_seed would otherwise
# remap a negative seed and still make runs deterministic).
if args.seed >= 0:
    torch.manual_seed(args.seed)
# When --nb_epochs is negative, choose a number of epochs depending on --data.
# NOTE(review): the statements under each branch are not visible in this
# copy; confirm each one assigns args.nb_epochs, since it is later used
# directly in range(args.nb_epochs).
69 if args.nb_epochs < 0:
70 if args.data == 'toy1d':
72 elif args.data == 'mnist':
75 ######################################################################
# Run on the GPU when one is present, otherwise on the CPU. The CPU
# assignment must live in an `else` branch, or it would overwrite the CUDA
# device unconditionally.
if torch.cuda.is_available():
    print('Cuda is available')
    device = torch.device('cuda')
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')
84 ######################################################################
class NetToy1d(nn.Module):
    """Causal 1d convolutional network for sequence autoregression.

    The logits at position t depend only on the input values at positions
    strictly before t.

    Args:
        nb_classes: number of output classes per position.
        ks: kernel size of the four hidden causal convolutions.
        nc: number of hidden channels.

    forward(x): x of shape (N, 1, T) -> logits of shape (N, T, nb_classes).
    """

    def __init__(self, nb_classes, ks = 2, nc = 32):
        super().__init__()
        # Left-only padding of ks - 1 keeps each hidden convolution causal and
        # the sequence length unchanged.
        self.pad = (ks - 1, 0)
        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)

    def forward(self, x):
        # Shift right by one step (pad 1 on the left, crop 1 on the right), so
        # position t sees x[..., t-1] and never x[..., t] itself.
        x = F.relu(self.conv0(F.pad(x, (1, -1))))
        x = F.relu(self.conv1(F.pad(x, self.pad)))
        x = F.relu(self.conv2(F.pad(x, self.pad)))
        x = F.relu(self.conv3(F.pad(x, self.pad)))
        x = F.relu(self.conv4(F.pad(x, self.pad)))
        # Project the hidden channels to per-position class logits.
        x = self.conv5(x)
        # (N, nb_classes, T) -> (N, T, nb_classes)
        return x.permute(0, 2, 1).contiguous()
class NetToy1dWithDilation(nn.Module):
    """Causal 1d convolutional network with dilated hidden convolutions.

    Same interface as NetToy1d. The four hidden convolutions use dilations
    2, 4, 8 and 16.

    Args:
        nb_classes: number of output classes per position.
        ks: kernel size of the hidden convolutions.
        nc: number of hidden channels.

    forward(x): x of shape (N, 1, T) -> logits of shape (N, T, nb_classes).

    NOTE(review): every dilation is even, so each layer only adds even
    offsets. The output at t therefore depends on x[t-1], x[t-3], x[t-5], ...
    and never on x[t-2], x[t-4], .... Confirm this is intended; starting the
    stack at dilation 1 would close the gap.
    """

    def __init__(self, nb_classes, ks = 2, nc = 32):
        super().__init__()
        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
        # A causal conv with dilation d needs (ks - 1) * d steps of left
        # padding to keep the sequence length unchanged. padK goes with convK.
        self.pad1 = ((ks-1) * 2, 0)
        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
        self.pad2 = ((ks-1) * 4, 0)
        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
        self.pad3 = ((ks-1) * 8, 0)
        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
        self.pad4 = ((ks-1) * 16, 0)
        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)

    def forward(self, x):
        # Shift right by one step so position t never sees x[..., t].
        x = F.relu(self.conv0(F.pad(x, (1, -1))))
        # Each convolution uses the padding matching its own dilation. The old
        # code used pad2..pad5, which lengthened the sequence and referenced
        # a non-existent pad5.
        x = F.relu(self.conv1(F.pad(x, self.pad1)))
        x = F.relu(self.conv2(F.pad(x, self.pad2)))
        x = F.relu(self.conv3(F.pad(x, self.pad3)))
        x = F.relu(self.conv4(F.pad(x, self.pad4)))
        # Project to per-position class logits.
        x = self.conv5(x)
        # (N, nb_classes, T) -> (N, T, nb_classes)
        return x.permute(0, 2, 1).contiguous()
129 ######################################################################
class PixelCNN(nn.Module):
    """Causal 2d convolutional network over images in raster-scan order.

    Two streams are combined:
      - the "h" stream sees a window of the rows strictly above the pixel;
      - the "v" stream sees a window of the pixels strictly to its left on
        the same row.
    The logits at (i, j) therefore never depend on pixel (i, j) or on any
    pixel after it in raster order.

    Args:
        nb_classes: number of output classes per pixel.
        in_channels: number of input channels.
        ks: spatial kernel size. NOTE(review): must be odd; with an even ks
            the symmetric width padding of the "h" stream grows the width by 1.

    forward(x): x of shape (N, in_channels, H, W) -> logits of shape
    (N, H, W, nb_classes).
    """

    def __init__(self, nb_classes, in_channels = 1, ks = 5):
        super().__init__()

        # F.pad order for 4d tensors: (left, right, top, bottom).
        # "h": same-width padding, plus ks//2 rows on top only (causal in H).
        self.hpad = (ks//2, ks//2, ks//2, 0)
        # "v": ks//2 columns on the left only (causal in W).
        self.vpad = (ks//2, 0, 0, 0)

        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
        self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
        self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
        self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
        self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)

    def forward(self, x):
        # Shift down one row (pad top, crop bottom): row i of the "h" stream
        # only sees input rows < i.
        xh = F.pad(x, (0, 0, 1, -1))
        # Shift right one column: column j of the "v" stream only sees input
        # columns < j of the same row.
        xv = F.pad(x, (1, -1, 0, 0))
        xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
        xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
        xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
        xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
        x = F.relu(self.final1(torch.cat((xh, xv), 1)))
        # Project to per-pixel class logits.
        x = self.final2(x)
        # (N, nb_classes, H, W) -> (N, H, W, nb_classes)
        return x.permute(0, 2, 3, 1).contiguous()
157 ######################################################################
def positional_tensor(height, width):
    """Binary positional encoding of a height x width grid.

    Returns an int64 tensor of shape (1, nb_h + nb_w, height, width), where
    nb_h = (height - 1).bit_length() and nb_w = (width - 1).bit_length()
    (the number of bits needed for indices 0..n-1).
    Channel k < nb_h holds bit k (least significant first) of the row index.
    Channel nb_h + k holds bit k of the column index.

    The previous version built the two halves with shapes (height, height)
    and (width, width), so torch.cat failed for non-square grids. For square
    grids the output is unchanged.
    """
    def bits(n):
        # (nb_bits, n) table whose entry [k, i] is bit k of i.
        # (n - 1).bit_length() == ceil(log2(n)) for n >= 1, computed exactly.
        index = torch.arange(n).view(1, -1)
        nb_bits = max(n - 1, 0).bit_length()
        powers = (2 ** torch.arange(nb_bits)).view(-1, 1)
        return (index // powers) % 2

    b_h, b_w = bits(height), bits(width)
    # Row bits vary along dim 2 (rows), column bits along dim 3 (columns).
    i_h = b_h[None, :, :, None].expand(-1, -1, -1, width)
    i_w = b_w[None, :, None, :].expand(-1, -1, height, -1)
    return torch.cat((i_h, i_w), 1)
172 ######################################################################
# Build an experiment tag from the options; it names every output file
# ('causalar-<tag>-train.log', '...-train.dat', '...-generated.png', ...).
174 str_experiment = args.data
# NOTE(review): the conditions guarding the two suffixes below are not
# visible here; presumably they test args.positional and args.dilation.
177 str_experiment += '-positional'
180 str_experiment += '-dilation'
# NOTE(review): log_file is opened once and never closed in the visible code;
# consider closing it at exit so buffered lines are not lost.
182 log_file = open('causalar-' + str_experiment + '-train.log', 'w')
# Body of the logging helper (called as log_string(...) further down):
# prefix the message with the local time and append it to log_file.
# NOTE(review): the enclosing `def log_string(s):` line and any print/flush
# lines are not visible in this copy.
185 s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
187 log_file.write(s + '\n')
190 ######################################################################
# Generate `nb` toy integer sequences of length `len`, returned as a long
# tensor of shape (nb, len).
# Each sequence is cut at random breakpoints into nb_parts segments. Each
# segment is a line that restarts at value `len` with a random slope in
# [-1, 1), so values lie roughly in [1, 2*len - 1] before rounding.
# NOTE(review): the parameter `len` shadows the builtin len().
# NOTE(review): nb_parts and the loop binding `n` are defined on lines not
# visible in this copy.
192 def generate_sequences(nb, len):
# r comes from torch.empty, so any entry not overwritten below is garbage.
195 r = torch.empty(nb, len)
# One random slope per (sequence, segment).
197 x = torch.empty(nb, nb_parts).uniform_(-1, 1)
198 x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
# Line of each slope over positions 0..len-1, offset by len.
199 x = x * torch.linspace(0, len-1, len).view(1, -1) + len
# nb_parts+1 sorted distinct breakpoints drawn from 0..len-3.
# NOTE(review): unless a hidden line forces a[0] = 0, positions before a[0]
# keep the uninitialised values of r — confirm.
202 a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
# The last breakpoint is forced to the end of the sequence.
204 a[a.size(0) - 1] = len
# Segment k covers [a[k], a[k+1]) and takes the first a[k+1]-a[k] values of line k.
205 for k in range(a.size(0) - 1):
206 r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
208 return r.round().long()
210 ######################################################################
# Build the training set and the model for the selected --data.
# toy1d: 50000 generated sequences of shape (N, 1, len); 2 * len classes per
# position covers the value range produced by generate_sequences.
# NOTE(review): `len` here is a module-level value set on a hidden line; if it
# is not set, it is the builtin len() and `2 * len` raises TypeError.
212 if args.data == 'toy1d':
214 train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
# NOTE(review): presumably chosen by a hidden `if args.dilation:` / `else:`.
216 model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
218 model = NetToy1d(nb_classes = 2 * len).to(device)
# mnist: 28x28 single-channel images as integer intensities; one of 256
# classes is predicted per pixel.
220 elif args.data == 'mnist':
221 train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
222 train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
# NOTE(review): this model is rebuilt a few lines below with in_channels;
# the first construction looks redundant (it only consumes RNG state).
224 model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
225 in_channels = train_input.size(1)
# Optional positional channels. NOTE(review): presumably guarded by a hidden
# `if args.positional:`; positional_input is only defined in this case.
228 height, width = train_input.size(2), train_input.size(3)
229 positional_input = positional_tensor(height, width).float().to(device)
230 in_channels += positional_input.size(1)
232 model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
# Any other --data value is rejected.
235 raise ValueError('Unknown data ' + args.data)
237 ######################################################################
# Global mean/std of the raw training values, used to standardise the
# network inputs. Targets are presumably the raw integer values (hidden line).
239 mean, std = train_input.float().mean(), train_input.float().std()
241 nb_parameters = sum(t.numel() for t in model.parameters())
242 log_string(f'nb_parameters {nb_parameters}')
244 cross_entropy = nn.CrossEntropyLoss().to(device)
245 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
# One pass over train_input per epoch, in fixed order (no shuffling visible).
247 for e in range(args.nb_epochs):
249 nb_batches, acc_loss = 0, 0.0
251 for sequences in train_input.split(args.batch_size):
# NOTE(review): `input` shadows the builtin.
252 input = (sequences - mean)/std
# NOTE(review): the opening `torch.cat(` and its guard are hidden; the
# positional channels must only be appended when positional_input exists.
256 (input, positional_input.expand(input.size(0), -1, -1, -1)),
260 output = model(input)
# Flatten logits to (batch * positions, nb_classes) for CrossEntropyLoss.
262 loss = cross_entropy(
263 output.view(-1, output.size(-1)),
267 optimizer.zero_grad()
# NOTE(review): loss.backward(), optimizer.step() and `nb_batches += 1` are
# not visible here; without the increment the division below fails.
272 acc_loss += loss.item()
# Epoch, mean batch loss, and its exponential (per-value perplexity).
274 log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
278 ######################################################################
# Autoregressive sampling of 48 examples. Start from zeros shaped like
# training examples; for each flattened position t, run the full model on
# the current samples and draw position t from the predicted categorical.
# NOTE(review): no torch.no_grad() is visible around this loop, so autograd
# records each of the flat.size(1) forward passes; wrapping the loop in
# torch.no_grad() would avoid that overhead.
280 generated = train_input.new_zeros((48,) + train_input.size()[1:])
# `flat` is a view sharing storage with `generated`: writing flat[:, t]
# updates `generated`, which is what the next iteration feeds to the model.
282 flat = generated.view(generated.size(0), -1)
284 for t in range(flat.size(1)):
285 input = (generated.float() - mean) / std
# NOTE(review): presumably guarded on a hidden line by the --positional flag.
287 input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
288 output = model(input)
# Logits of position t for every sample: (48, nb_classes).
289 logits = output.view(flat.size() + (-1,))[:, t]
290 dist = torch.distributions.categorical.Categorical(logits = logits)
291 flat[:, t] = dist.sample()
293 ######################################################################
# Dump results.
# toy1d: write up to 25 training sequences and all generated sequences as
# text, iterating over positions j in the outer loop (presumably one line
# per position; the newline/index writes are on hidden lines).
295 if args.data == 'toy1d':
297 with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
298 for j in range(train_input.size(2)):
300 for i in range(min(train_input.size(0), 25)):
301 file.write(f' {train_input[i, 0, j]}')
304 with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
305 for j in range(generated.size(2)):
307 for i in range(generated.size(0)):
308 file.write(f' {generated[i, 0, j]}')
# mnist: map intensities to 1 - v/255 (inverted, so digits are dark on
# white), and save as many training images as there are samples next to
# the generated ones.
311 elif args.data == 'mnist':
313 img_train = 1 - train_input[:generated.size(0)].float() / 255
314 img_generated = 1 - generated.float() / 255
316 save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
317 save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
319 ######################################################################