From 762a2c5e2485e0ebd7c26fe980893a4de2544bb9 Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Sat, 29 Feb 2020 14:25:01 +0100
Subject: [PATCH] Initial commit.

---
 causal-autoregression.py | 319 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 319 insertions(+)
 create mode 100755 causal-autoregression.py

diff --git a/causal-autoregression.py b/causal-autoregression.py
new file mode 100755
index 0000000..c2f6161
--- /dev/null
+++ b/causal-autoregression.py
@@ -0,0 +1,319 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret
+
+# ./causal-autoregression.py --data=toy1d
+# ./causal-autoregression.py --data=toy1d --dilation
+# ./causal-autoregression.py --data=mnist
+# ./causal-autoregression.py --data=mnist --positional
+
+import argparse, math, sys, time
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+def save_images(x, filename, nrow = 12):
+    print(f'Writing {filename}')
+    torchvision.utils.save_image(x.narrow(0, 0, min(48, x.size(0))),
+                                 filename,
+                                 nrow = nrow, pad_value = 1.0)
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    description = 'An implementation of a causal autoregression model',
+    formatter_class = argparse.ArgumentDefaultsHelpFormatter
+)
+
+parser.add_argument('--data',
+                    type = str, default = 'toy1d',
+                    help = 'What data (toy1d or mnist)')
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
+parser.add_argument('--nb_epochs',
+                    type = int, default = -1,
+                    help = 'How many epochs')
+
+parser.add_argument('--batch_size',
+                    type = int, default = 100,
+                    help = 'Batch size')
+
+parser.add_argument('--learning_rate',
+                    type = float, default = 1e-3,
+                    help = 'Learning rate')
+
+parser.add_argument('--positional',
+                    action = 'store_true', default = False,
+                    help = 'Do we provide a positional encoding as input')
+
+parser.add_argument('--dilation',
+                    action = 'store_true', default = False,
+                    help = 'Do we use dilated convolutions in the toy1d model')
+
+######################################################################
+
+args = parser.parse_args()
+
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
+
+if args.nb_epochs < 0:
+    if args.data == 'toy1d':
+        args.nb_epochs = 100
+    elif args.data == 'mnist':
+        args.nb_epochs = 25
+
+######################################################################
+
+if torch.cuda.is_available():
+    print('Cuda is available')
+    device = torch.device('cuda')
+    torch.backends.cudnn.benchmark = True
+else:
+    device = torch.device('cpu')
+
+######################################################################
+
+class NetToy1d(nn.Module):
+    def __init__(self, nb_classes, ks = 2, nc = 32):
+        super(NetToy1d, self).__init__()
+        self.pad = (ks - 1, 0)
+        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        # Shift the input one step to the right, then keep every convolution
+        # left-padded, so that the prediction at t only depends on x[:t]
+        x = F.relu(self.conv0(F.pad(x, (1, -1))))
+        x = F.relu(self.conv1(F.pad(x, self.pad)))
+        x = F.relu(self.conv2(F.pad(x, self.pad)))
+        x = F.relu(self.conv3(F.pad(x, self.pad)))
+        x = F.relu(self.conv4(F.pad(x, self.pad)))
+        x = self.conv5(x)
+        return x.permute(0, 2, 1).contiguous()
+
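+# Illustrative addition, not in the original script: a small sanity check that
+# the toy 1d model is indeed causal. It perturbs the input at a single time
+# step and verifies that the outputs at that step and before it are unchanged,
+# since the initial F.pad(x, (1, -1)) shifts the input one step to the right.
+# The helper name check_causality is made up for this sketch.
+
+def check_causality(model, seq_len = 32, t = 16):
+    x = torch.randn(1, 1, seq_len)
+    x_perturbed = x.clone()
+    x_perturbed[0, 0, t] += 1.0
+    y, y_perturbed = model(x), model(x_perturbed)
+    # Output is (N, T, C); positions 0..t must be unaffected by the change at t
+    assert torch.allclose(y[:, :t + 1], y_perturbed[:, :t + 1])
+
+# e.g. check_causality(NetToy1d(nb_classes = 64))
+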
+class NetToy1dWithDilation(nn.Module):
+    def __init__(self, nb_classes, ks = 2, nc = 32):
+        super(NetToy1dWithDilation, self).__init__()
+        self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
+        self.pad1 = ((ks - 1) * 2, 0)
+        self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
+        self.pad2 = ((ks - 1) * 4, 0)
+        self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
+        self.pad3 = ((ks - 1) * 8, 0)
+        self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
+        self.pad4 = ((ks - 1) * 16, 0)
+        self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
+        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        x = F.relu(self.conv0(F.pad(x, (1, -1))))
+        x = F.relu(self.conv1(F.pad(x, self.pad1)))
+        x = F.relu(self.conv2(F.pad(x, self.pad2)))
+        x = F.relu(self.conv3(F.pad(x, self.pad3)))
+        x = F.relu(self.conv4(F.pad(x, self.pad4)))
+        x = self.conv5(x)
+        return x.permute(0, 2, 1).contiguous()
+
+######################################################################
+
+class PixelCNN(nn.Module):
+    def __init__(self, nb_classes, in_channels = 1, ks = 5):
+        super(PixelCNN, self).__init__()
+
+        self.hpad = (ks//2, ks//2, ks//2, 0)
+        self.vpad = (ks//2, 0, 0, 0)
+
+        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
+        self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
+        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
+        self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
+        self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
+        self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+
+    def forward(self, x):
+        xh = F.pad(x, (0, 0, 1, -1))
+        xv = F.pad(x, (1, -1, 0, 0))
+        xh = F.relu(self.conv1h(F.pad(xh, self.hpad)))
+        xv = F.relu(self.conv1v(F.pad(xv, self.vpad)))
+        xh = F.relu(self.conv2h(F.pad(xh, self.hpad)))
+        xv = F.relu(self.conv2v(F.pad(xv, self.vpad)))
+        x = F.relu(self.final1(torch.cat((xh, xv), 1)))
+        x = self.final2(x)
+
+        return x.permute(0, 2, 3, 1).contiguous()
+
+######################################################################
+
+def positional_tensor(height, width):
+    index_h = torch.arange(height).view(1, -1)
+    m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
+    b_h = (index_h // m_h) % 2
+    i_h = b_h[None, :, None, :].expand(-1, -1, height, -1)
+
+    index_w = torch.arange(width).view(1, -1)
+    m_w = (2 ** torch.arange(math.ceil(math.log2(width)))).view(-1, 1)
+    b_w = (index_w // m_w) % 2
+    i_w = b_w[None, :, :, None].expand(-1, -1, -1, width)
+
+    return torch.cat((i_w, i_h), 1)
+
+######################################################################
+
+str_experiment = args.data
+
+if args.positional:
+    str_experiment += '-positional'
+
+if args.dilation:
+    str_experiment += '-dilation'
+
+log_file = open('causalar-' + str_experiment + '-train.log', 'w')
+
+def log_string(s):
+    s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+    print(s)
+    log_file.write(s + '\n')
+    log_file.flush()
+
+######################################################################
+
+def generate_sequences(nb, len):
+    nb_parts = 2
+
+    r = torch.empty(nb, len)
+
+    x = torch.empty(nb, nb_parts).uniform_(-1, 1)
+    x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
+    x = x * torch.linspace(0, len - 1, len).view(1, -1) + len
+
+    for n in range(nb):
+        a = torch.randperm(len - 2)[:nb_parts + 1].sort()[0]
+        a[0] = 0
+        a[a.size(0) - 1] = len
+        for k in range(a.size(0) - 1):
+            r[n, a[k]:a[k+1]] = x[n, k, :a[k+1] - a[k]]
+
+    return r.round().long()
+
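+# Illustrative addition, not in the original script: what positional_tensor
+# produces. For an HxW grid it returns a tensor of shape
+# (1, ceil(log2(W)) + ceil(log2(H)), H, W) whose channels are binary encodings
+# of the pixel coordinates, so the network can condition on absolute position.
+# For 28x28 MNIST images this adds 5 + 5 = 10 input channels.
+
+assert positional_tensor(28, 28).size() == (1, 10, 28, 28)
+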
+######################################################################
+
+if args.data == 'toy1d':
+    len = 32
+    train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
+    if args.dilation:
+        model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
+    else:
+        model = NetToy1d(nb_classes = 2 * len).to(device)
+
+elif args.data == 'mnist':
+    train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+    train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
+
+    in_channels = train_input.size(1)
+
+    if args.positional:
+        height, width = train_input.size(2), train_input.size(3)
+        positional_input = positional_tensor(height, width).float().to(device)
+        in_channels += positional_input.size(1)
+
+    model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+
+else:
+    raise ValueError('Unknown data ' + args.data)
+
+######################################################################
+
+mean, std = train_input.float().mean(), train_input.float().std()
+
+nb_parameters = sum(t.numel() for t in model.parameters())
+log_string(f'nb_parameters {nb_parameters}')
+
+cross_entropy = nn.CrossEntropyLoss().to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+
+for e in range(args.nb_epochs):
+
+    nb_batches, acc_loss = 0, 0.0
+
+    for sequences in train_input.split(args.batch_size):
+        input = (sequences - mean) / std
+
+        if args.positional:
+            input = torch.cat(
+                (input, positional_input.expand(input.size(0), -1, -1, -1)),
+                1
+            )
+
+        output = model(input)
+
+        loss = cross_entropy(
+            output.view(-1, output.size(-1)),
+            sequences.view(-1)
+        )
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        nb_batches += 1
+        acc_loss += loss.item()
+
+    log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+
+    sys.stdout.flush()
+
+######################################################################
+
+# Autoregressive sampling: generate 48 sequences/images one symbol at a time,
+# re-running the model on what has been generated so far and sampling position
+# t from the predicted categorical distribution
+
+generated = train_input.new_zeros((48,) + train_input.size()[1:])
+
+flat = generated.view(generated.size(0), -1)
+
+for t in range(flat.size(1)):
+    input = (generated.float() - mean) / std
+    if args.positional:
+        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+    output = model(input)
+    logits = output.view(flat.size() + (-1,))[:, t]
+    dist = torch.distributions.categorical.Categorical(logits = logits)
+    flat[:, t] = dist.sample()
+
+######################################################################
+
+if args.data == 'toy1d':
+
+    with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+        for j in range(train_input.size(2)):
+            file.write(f'{j}')
+            for i in range(min(train_input.size(0), 25)):
+                file.write(f' {train_input[i, 0, j]}')
+            file.write('\n')
+
+    with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+        for j in range(generated.size(2)):
+            file.write(f'{j}')
+            for i in range(generated.size(0)):
+                file.write(f' {generated[i, 0, j]}')
+            file.write('\n')
+
+elif args.data == 'mnist':
+
+    img_train = 1 - train_input[:generated.size(0)].float() / 255
+    img_generated = 1 - generated.float() / 255
+
+    save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
+    save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+
+######################################################################
-- 
2.39.5
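
(Not part of the patch above: a minimal sketch of how the toy1d .dat files
written by the script could be plotted. It assumes numpy and matplotlib are
available; the file name below corresponds to a run without --dilation, and
the helper name plot_toy1d.py is made up for this example.)

    # plot_toy1d.py -- hypothetical helper, not included in the commit above.
    # Each row of the .dat file is "t v1 v2 ... vN": column 0 is the time
    # index, the remaining columns are the sampled sequences.
    import numpy as np
    import matplotlib.pyplot as plt

    data = np.loadtxt('causalar-toy1d-generated.dat')
    t, sequences = data[:, 0], data[:, 1:]
    for k in range(sequences.shape[1]):
        plt.plot(t, sequences[:, k], lw = 0.5)
    plt.xlabel('t')
    plt.savefig('causalar-toy1d-generated.png')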