Initial commit.
authorFrancois Fleuret <francois@fleuret.org>
Sat, 29 Feb 2020 13:25:01 +0000 (14:25 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Sat, 29 Feb 2020 13:25:01 +0000 (14:25 +0100)
causal-autoregression.py [new file with mode: 0755]

diff --git a/causal-autoregression.py b/causal-autoregression.py
new file mode 100755 (executable)
index 0000000..c2f6161
--- /dev/null
@@ -0,0 +1,319 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# ./causal-autoregression.py --data=toy1d
+# ./causal-autoregression.py --data=toy1d --dilation
+# ./causal-autoregression.py --data=mnist
+# ./causal-autoregression.py --data=mnist --positional
+
+import argparse, math, sys, time
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+def save_images(x, filename, nrow = 12):
+    print(f'Writing {filename}')
+    torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
+                                 filename,
+                                 nrow = nrow, pad_value=1.0)
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description = 'An implementation of a causal autoregression model',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+                    type = str, default = 'toy1d',
+                    help = 'What data')
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = -1,
+                    help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100,
+                    help = 'Batch size')
+
+parser.add_argument('--learning_rate',
+                    type = float, default = 1e-3,
+                    help = 'Batch size')
+
+parser.add_argument('--positional',
+                    action='store_true', default = False,
+                    help = 'Do we provide a positional encoding as input')
+
+parser.add_argument('--dilation',
+                    action='store_true', default = False,
+                    help = 'Do we provide a positional encoding as input')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+if args.nb_epochs < 0:
+    if args.data == 'toy1d':
+        args.nb_epochs = 100
+    elif args.data == 'mnist':
+        args.nb_epochs = 25
+
+######################################################################
+
+if torch.cuda.is_available():
+    print('Cuda is available')
+    device = torch.device('cuda')
+    torch.backends.cudnn.benchmark = True
+else:
+    device = torch.device('cpu')
+
+######################################################################
+
+class NetToy1d(nn.Module):
+    def __init__(self, nb_classes, ks = 2, nc = 32):
+        super(NetToy1d, self).__init__()
+        self.pad = (ks - 1, 0)
+        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        x = F.relu(self.conv0(F.pad(x, (1, -1))))
+        x = F.relu(self.conv1(F.pad(x, self.pad)))
+        x = F.relu(self.conv2(F.pad(x, self.pad)))
+        x = F.relu(self.conv3(F.pad(x, self.pad)))
+        x = F.relu(self.conv4(F.pad(x, self.pad)))
+        x = self.conv5(x)
+        return x.permute(0, 2, 1).contiguous()
+
+class NetToy1dWithDilation(nn.Module):
+    def __init__(self, nb_classes, ks = 2, nc = 32):
+        super(NetToy1dWithDilation, self).__init__()
+        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+        self.pad1 = ((ks-1) * 2, 0)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
+        self.pad2 = ((ks-1) * 4, 0)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
+        self.pad3 = ((ks-1) * 8, 0)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
+        self.pad4 = ((ks-1) * 16, 0)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        x = F.relu(self.conv0(F.pad(x, (1, -1))))
+        x = F.relu(self.conv1(F.pad(x, self.pad2)))
+        x = F.relu(self.conv2(F.pad(x, self.pad3)))
+        x = F.relu(self.conv3(F.pad(x, self.pad4)))
+        x = F.relu(self.conv4(F.pad(x, self.pad5)))
+        x = self.conv5(x)
+        return x.permute(0, 2, 1).contiguous()
+
+######################################################################
+
+class PixelCNN(nn.Module):
+    def __init__(self, nb_classes, in_channels = 1, ks = 5):
+        super(PixelCNN, self).__init__()
+
+        self.hpad = (ks//2, ks//2, ks//2, 0)
+        self.vpad = (ks//2,     0,     0, 0)
+
+        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
+        self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
+        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
+        self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
+        self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
+        self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        xh = F.pad(x, (0, 0, 1, -1))
+        xv = F.pad(x, (1, -1, 0, 0))
+        xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
+        xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
+        xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
+        xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
+        x = F.relu(self.final1(torch.cat((xh, xv), 1)))
+        x = self.final2(x)
+
+        return x.permute(0, 2, 3, 1).contiguous()
+
+######################################################################
+
+def positional_tensor(height, width):
+    index_h = torch.arange(height).view(1, -1)
+    m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
+    b_h = (index_h // m_h) % 2
+    i_h = b_h[None, :, None, :].expand(-1, -1, height, -1)
+
+    index_w = torch.arange(width).view(1, -1)
+    m_w = (2 ** torch.arange(math.ceil(math.log2(width)))).view(-1, 1)
+    b_w = (index_w // m_w) % 2
+    i_w = b_w[None, :, :, None].expand(-1, -1, -1, width)
+
+    return torch.cat((i_w, i_h), 1)
+
+######################################################################
+
+str_experiment = args.data
+
+if args.positional:
+    str_experiment += '-positional'
+
+if args.dilation:
+    str_experiment += '-dilation'
+
+log_file = open('causalar-' + str_experiment + '-train.log', 'w')
+
+def log_string(s):
+    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+    print(s)
+    log_file.write(s + '\n')
+    log_file.flush()
+
+######################################################################
+
+def generate_sequences(nb, len):
+    nb_parts = 2
+
+    r = torch.empty(nb, len)
+
+    x = torch.empty(nb, nb_parts).uniform_(-1, 1)
+    x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
+    x = x * torch.linspace(0, len-1, len).view(1, -1) + len
+
+    for n in range(nb):
+        a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
+        a[0] = 0
+        a[a.size(0) - 1] = len
+        for k in range(a.size(0) - 1):
+            r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
+
+    return r.round().long()
+
+######################################################################
+
+if args.data == 'toy1d':
+    len = 32
+    train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
+    if args.dilation:
+        model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
+    else:
+        model = NetToy1d(nb_classes = 2 * len).to(device)
+
+elif args.data == 'mnist':
+    train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+    train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
+
+    model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
+    in_channels = train_input.size(1)
+
+    if args.positional:
+        height, width = train_input.size(2), train_input.size(3)
+        positional_input = positional_tensor(height, width).float().to(device)
+        in_channels += positional_input.size(1)
+
+    model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+
+else:
+    raise ValueError('Unknown data ' + args.data)
+
+######################################################################
+
+mean, std = train_input.float().mean(), train_input.float().std()
+
+nb_parameters = sum(t.numel() for t in model.parameters())
+log_string(f'nb_parameters {nb_parameters}')
+
+cross_entropy = nn.CrossEntropyLoss().to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+
+for e in range(args.nb_epochs):
+
+    nb_batches, acc_loss = 0, 0.0
+
+    for sequences in train_input.split(args.batch_size):
+        input = (sequences - mean)/std
+
+        if args.positional:
+            input = torch.cat(
+                (input, positional_input.expand(input.size(0), -1, -1, -1)),
+                1
+            )
+
+        output = model(input)
+
+        loss = cross_entropy(
+            output.view(-1, output.size(-1)),
+            sequences.view(-1)
+        )
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        nb_batches += 1
+        acc_loss += loss.item()
+
+    log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+
+    sys.stdout.flush()
+
+######################################################################
+
+generated = train_input.new_zeros((48,) + train_input.size()[1:])
+
+flat = generated.view(generated.size(0), -1)
+
+for t in range(flat.size(1)):
+    input = (generated.float() - mean) / std
+    if args.positional:
+        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+    output = model(input)
+    logits = output.view(flat.size() + (-1,))[:, t]
+    dist = torch.distributions.categorical.Categorical(logits = logits)
+    flat[:, t] = dist.sample()
+
+######################################################################
+
+if args.data == 'toy1d':
+
+    with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+        for j in range(train_input.size(2)):
+            file.write(f'{j}')
+            for i in range(min(train_input.size(0), 25)):
+                file.write(f' {train_input[i, 0, j]}')
+            file.write('\n')
+
+    with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+        for j in range(generated.size(2)):
+            file.write(f'{j}')
+            for i in range(generated.size(0)):
+                file.write(f' {generated[i, 0, j]}')
+            file.write('\n')
+
+elif args.data == 'mnist':
+
+    img_train = 1 - train_input[:generated.size(0)].float() / 255
+    img_generated = 1 - generated.float() / 255
+
+    save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
+    save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+
+######################################################################