--- /dev/null
+++ b/causal-autoregression.py
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# ./causal-autoregression.py --data=toy1d
+# ./causal-autoregression.py --data=toy1d --dilation
+# ./causal-autoregression.py --data=mnist
+# ./causal-autoregression.py --data=mnist --positional
+
+import argparse, math, sys, time
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+def save_images(x, filename, nrow = 12):
+ print(f'Writing {filename}')
+ torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
+ filename,
+ nrow = nrow, pad_value=1.0)
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+ description = 'An implementation of a causal autoregression model',
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+ type = str, default = 'toy1d',
+                    help = 'What data: toy1d or mnist')
+
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--nb_epochs',
+ type = int, default = -1,
+                    help = 'How many epochs (< 0 selects a per-dataset default)')
+
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Batch size')
+
+parser.add_argument('--learning_rate',
+ type = float, default = 1e-3,
+                    help = 'Learning rate')
+
+parser.add_argument('--positional',
+ action='store_true', default = False,
+ help = 'Do we provide a positional encoding as input')
+
+parser.add_argument('--dilation',
+ action='store_true', default = False,
+                    help = 'Do we use dilated convolutions in the toy1d model')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
+if args.nb_epochs < 0:
+ if args.data == 'toy1d':
+ args.nb_epochs = 100
+ elif args.data == 'mnist':
+ args.nb_epochs = 25
+
+######################################################################
+
+if torch.cuda.is_available():
+ print('Cuda is available')
+ device = torch.device('cuda')
+ torch.backends.cudnn.benchmark = True
+else:
+ device = torch.device('cpu')
+
+######################################################################
+
+class NetToy1d(nn.Module):
+ def __init__(self, nb_classes, ks = 2, nc = 32):
+ super(NetToy1d, self).__init__()
+ self.pad = (ks - 1, 0)
+ self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+ self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
+ self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
+ self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
+ self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
+ self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+ def forward(self, x):
+        # shift the input right by one, so that the output at t only
+        # depends on values at positions strictly before t
+        x = F.relu(self.conv0(F.pad(x, (1, -1))))
+ x = F.relu(self.conv1(F.pad(x, self.pad)))
+ x = F.relu(self.conv2(F.pad(x, self.pad)))
+ x = F.relu(self.conv3(F.pad(x, self.pad)))
+ x = F.relu(self.conv4(F.pad(x, self.pad)))
+ x = self.conv5(x)
+ return x.permute(0, 2, 1).contiguous()
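+
+# A minimal causality check (illustrative, assuming the default sizes):
+# perturbing the input at position t must leave the outputs at positions
+# <= t unchanged, since every output sees only strictly earlier inputs.
+#
+#   net = NetToy1d(nb_classes = 64)
+#   x = torch.randn(2, 1, 32)
+#   y = net(x)
+#   x[:, :, 16] += 1.0
+#   assert torch.equal(net(x)[:, :17], y[:, :17])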
+
+class NetToy1dWithDilation(nn.Module):
+ def __init__(self, nb_classes, ks = 2, nc = 32):
+ super(NetToy1dWithDilation, self).__init__()
+ self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+ self.pad1 = ((ks-1) * 2, 0)
+ self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
+ self.pad2 = ((ks-1) * 4, 0)
+ self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
+ self.pad3 = ((ks-1) * 8, 0)
+ self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
+ self.pad4 = ((ks-1) * 16, 0)
+ self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
+ self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+ def forward(self, x):
+ x = F.relu(self.conv0(F.pad(x, (1, -1))))
+        x = F.relu(self.conv1(F.pad(x, self.pad1)))
+        x = F.relu(self.conv2(F.pad(x, self.pad2)))
+        x = F.relu(self.conv3(F.pad(x, self.pad3)))
+        x = F.relu(self.conv4(F.pad(x, self.pad4)))
+ x = self.conv5(x)
+ return x.permute(0, 2, 1).contiguous()
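+
+# With the default ks = 2, the four dilated layers give a causal
+# receptive field of (ks-1) * (2 + 4 + 8 + 16) = 30 positions; combined
+# with the one-step shift in conv0, the output at t sees the inputs
+# t-31, ..., t-1, enough to cover the toy sequences of length 32.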
+
+######################################################################
+
+class PixelCNN(nn.Module):
+ def __init__(self, nb_classes, in_channels = 1, ks = 5):
+ super(PixelCNN, self).__init__()
+
+ self.hpad = (ks//2, ks//2, ks//2, 0)
+ self.vpad = (ks//2, 0, 0, 0)
+
+ self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
+ self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
+ self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
+ self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
+ self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
+ self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+
+ def forward(self, x):
+        xh = F.pad(x, (0, 0, 1, -1)) # shift down by one row
+        xv = F.pad(x, (1, -1, 0, 0)) # shift right by one column
+ xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
+ xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
+ xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
+ xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
+ x = F.relu(self.final1(torch.cat((xh, xv), 1)))
+ x = self.final2(x)
+
+ return x.permute(0, 2, 3, 1).contiguous()
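+
+# A shape sanity check (illustrative, assuming MNIST and the default
+# ks = 5): one stack convolves over the rows above, the other over the
+# pixels to the left in the current row, and the 1x1 final layers merge
+# them into per-pixel logits.
+#
+#   net = PixelCNN(nb_classes = 256, in_channels = 1)
+#   assert net(torch.randn(8, 1, 28, 28)).shape == (8, 28, 28, 256)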
+
+######################################################################
+
+def positional_tensor(height, width):
+    # binary encoding of the row index, constant along each row
+    index_h = torch.arange(height).view(1, -1)
+    m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
+    b_h = (index_h // m_h) % 2
+    i_h = b_h[None, :, :, None].expand(-1, -1, -1, width)
+
+    # binary encoding of the column index, constant along each column
+    index_w = torch.arange(width).view(1, -1)
+    m_w = (2 ** torch.arange(math.ceil(math.log2(width)))).view(-1, 1)
+    b_w = (index_w // m_w) % 2
+    i_w = b_w[None, :, None, :].expand(-1, -1, height, -1)
+
+    return torch.cat((i_w, i_h), 1)
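+
+# For 28x28 MNIST this gives ceil(log2(28)) = 5 bits per coordinate,
+# i.e. a (1, 10, 28, 28) tensor of {0, 1} values that uniquely encodes
+# every pixel position. For instance:
+#
+#   assert positional_tensor(28, 28).shape == (1, 10, 28, 28)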
+
+######################################################################
+
+str_experiment = args.data
+
+if args.positional:
+ str_experiment += '-positional'
+
+if args.dilation:
+ str_experiment += '-dilation'
+
+log_file = open('causalar-' + str_experiment + '-train.log', 'w')
+
+def log_string(s):
+ s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+ print(s)
+ log_file.write(s + '\n')
+ log_file.flush()
+
+######################################################################
+
+def generate_sequences(nb, seq_len):
+    # Each sequence is the concatenation of nb_parts linear ramps whose
+    # slopes are drawn uniformly in [-1, 1], with random change points
+    nb_parts = 2
+
+    r = torch.empty(nb, seq_len)
+
+    x = torch.empty(nb, nb_parts).uniform_(-1, 1)
+    x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, seq_len)
+    x = x * torch.linspace(0, seq_len-1, seq_len).view(1, -1) + seq_len
+
+    for n in range(nb):
+        a = torch.randperm(seq_len - 2)[:nb_parts+1].sort()[0]
+        a[0] = 0
+        a[a.size(0) - 1] = seq_len
+        for k in range(a.size(0) - 1):
+            r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
+
+    return r.round().long()
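+
+# For instance, generate_sequences(4, 32) returns a (4, 32) LongTensor
+# with values in [1, 63], hence the 2 * seq_len = 64 classes below.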
+
+######################################################################
+
+if args.data == 'toy1d':
+    seq_len = 32
+    train_input = generate_sequences(50000, seq_len).to(device).unsqueeze(1)
+    if args.dilation:
+        model = NetToy1dWithDilation(nb_classes = 2 * seq_len).to(device)
+    else:
+        model = NetToy1d(nb_classes = 2 * seq_len).to(device)
+
+elif args.data == 'mnist':
+ train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+ train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
+
+    in_channels = train_input.size(1)
+
+ if args.positional:
+ height, width = train_input.size(2), train_input.size(3)
+ positional_input = positional_tensor(height, width).float().to(device)
+ in_channels += positional_input.size(1)
+
+ model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+
+else:
+ raise ValueError('Unknown data ' + args.data)
+
+######################################################################
+
+mean, std = train_input.float().mean(), train_input.float().std()
+
+nb_parameters = sum(t.numel() for t in model.parameters())
+log_string(f'nb_parameters {nb_parameters}')
+
+cross_entropy = nn.CrossEntropyLoss().to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+
+for e in range(args.nb_epochs):
+
+ nb_batches, acc_loss = 0, 0.0
+
+ for sequences in train_input.split(args.batch_size):
+ input = (sequences - mean)/std
+
+ if args.positional:
+ input = torch.cat(
+ (input, positional_input.expand(input.size(0), -1, -1, -1)),
+ 1
+ )
+
+ output = model(input)
+
+ loss = cross_entropy(
+ output.view(-1, output.size(-1)),
+ sequences.view(-1)
+ )
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ nb_batches += 1
+ acc_loss += loss.item()
+
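+    # log the average loss and exp(average loss), the per-symbol perplexity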
+ log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+
+ sys.stdout.flush()
+
+######################################################################
+
+generated = train_input.new_zeros((48,) + train_input.size()[1:])
+
+flat = generated.view(generated.size(0), -1)
+
+# Generate autoregressively: one full forward pass per position, in
+# raster order, sampling each value from the predicted distribution
+with torch.no_grad():
+    for t in range(flat.size(1)):
+        input = (generated.float() - mean) / std
+        if args.positional:
+            input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+        output = model(input)
+        logits = output.view(flat.size() + (-1,))[:, t]
+        dist = torch.distributions.categorical.Categorical(logits = logits)
+        flat[:, t] = dist.sample()
+
+######################################################################
+
+if args.data == 'toy1d':
+
+ with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+ for j in range(train_input.size(2)):
+ file.write(f'{j}')
+ for i in range(min(train_input.size(0), 25)):
+ file.write(f' {train_input[i, 0, j]}')
+ file.write('\n')
+
+ with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+ for j in range(generated.size(2)):
+ file.write(f'{j}')
+ for i in range(generated.size(0)):
+ file.write(f' {generated[i, 0, j]}')
+ file.write('\n')
+
+elif args.data == 'mnist':
+
+ img_train = 1 - train_input[:generated.size(0)].float() / 255
+ img_generated = 1 - generated.float() / 255
+
+ save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
+ save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+
+######################################################################