######################################################################
-def save_images(x, filename, nrow = 12):
- print(f'Writing {filename}')
- torchvision.utils.save_image(x.narrow(0,0, min(48, x.size(0))),
- filename,
- nrow = nrow, pad_value=1.0)
+
+def save_images(x, filename, nrow=12):
+ print(f"Writing {filename}")
+ torchvision.utils.save_image(
+ x.narrow(0, 0, min(48, x.size(0))), filename, nrow=nrow, pad_value=1.0
+ )
+
######################################################################
parser = argparse.ArgumentParser(
- description = 'An implementation of a causal autoregression model',
- formatter_class = argparse.ArgumentDefaultsHelpFormatter
+ description="An implementation of a causal autoregression model",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument('--data',
- type = str, default = 'toy1d',
- help = 'What data')
+parser.add_argument("--data", type=str, default="toy1d", help="What data")
-parser.add_argument('--seed',
- type = int, default = 0,
- help = 'Random seed (default 0, < 0 is no seeding)')
+parser.add_argument(
+ "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
+)
-parser.add_argument('--nb_epochs',
- type = int, default = -1,
- help = 'How many epochs')
+parser.add_argument("--nb_epochs", type=int, default=-1, help="How many epochs")
-parser.add_argument('--batch_size',
- type = int, default = 100,
- help = 'Batch size')
+parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
-parser.add_argument('--learning_rate',
- type = float, default = 1e-3,
- help = 'Batch size')
+parser.add_argument("--learning_rate", type=float, default=1e-3, help="Batch size")
-parser.add_argument('--positional',
- action='store_true', default = False,
- help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+ "--positional",
+ action="store_true",
+ default=False,
+ help="Do we provide a positional encoding as input",
+)
-parser.add_argument('--dilation',
- action='store_true', default = False,
- help = 'Do we provide a positional encoding as input')
+parser.add_argument(
+ "--dilation",
+ action="store_true",
+ default=False,
+ help="Do we provide a positional encoding as input",
+)
######################################################################
torch.manual_seed(args.seed)
if args.nb_epochs < 0:
- if args.data == 'toy1d':
+ if args.data == "toy1d":
args.nb_epochs = 100
- elif args.data == 'mnist':
+ elif args.data == "mnist":
args.nb_epochs = 25
######################################################################
if torch.cuda.is_available():
- print('Cuda is available')
- device = torch.device('cuda')
+ print("Cuda is available")
+ device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
else:
- device = torch.device('cpu')
+ device = torch.device("cpu")
######################################################################
+
class NetToy1d(nn.Module):
- def __init__(self, nb_classes, ks = 2, nc = 32):
- super(NetToy1d, self).__init__()
+ def __init__(self, nb_classes, ks=2, nc=32):
+ super().__init__()
self.pad = (ks - 1, 0)
- self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
- self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks)
- self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks)
- self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks)
- self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks)
- self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+ self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+ self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks)
+ self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks)
+ self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks)
+ self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks)
+ self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
def forward(self, x):
x = F.relu(self.conv0(F.pad(x, (1, -1))))
x = self.conv5(x)
return x.permute(0, 2, 1).contiguous()
+
class NetToy1dWithDilation(nn.Module):
- def __init__(self, nb_classes, ks = 2, nc = 32):
- super(NetToy1dWithDilation, self).__init__()
- self.conv0 = nn.Conv1d(1, nc, kernel_size = 1)
- self.pad1 = ((ks-1) * 2, 0)
- self.conv1 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 2)
- self.pad2 = ((ks-1) * 4, 0)
- self.conv2 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 4)
- self.pad3 = ((ks-1) * 8, 0)
- self.conv3 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 8)
- self.pad4 = ((ks-1) * 16, 0)
- self.conv4 = nn.Conv1d(nc, nc, kernel_size = ks, dilation = 16)
- self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size = 1)
+ def __init__(self, nb_classes, ks=2, nc=32):
+ super().__init__()
+ self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
+ self.pad1 = ((ks - 1) * 2, 0)
+ self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=2)
+ self.pad2 = ((ks - 1) * 4, 0)
+ self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=4)
+ self.pad3 = ((ks - 1) * 8, 0)
+ self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=8)
+ self.pad4 = ((ks - 1) * 16, 0)
+ self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=16)
+ self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
def forward(self, x):
x = F.relu(self.conv0(F.pad(x, (1, -1))))
x = self.conv5(x)
return x.permute(0, 2, 1).contiguous()
+
######################################################################
+
class PixelCNN(nn.Module):
- def __init__(self, nb_classes, in_channels = 1, ks = 5):
- super(PixelCNN, self).__init__()
+ def __init__(self, nb_classes, in_channels=1, ks=5):
+ super().__init__()
- self.hpad = (ks//2, ks//2, ks//2, 0)
- self.vpad = (ks//2, 0, 0, 0)
+ self.hpad = (ks // 2, ks // 2, ks // 2, 0)
+ self.vpad = (ks // 2, 0, 0, 0)
- self.conv1h = nn.Conv2d(in_channels, 32, kernel_size = (ks//2+1, ks))
- self.conv2h = nn.Conv2d(32, 64, kernel_size = (ks//2+1, ks))
- self.conv1v = nn.Conv2d(in_channels, 32, kernel_size = (1, ks//2+1))
- self.conv2v = nn.Conv2d(32, 64, kernel_size = (1, ks//2+1))
- self.final1 = nn.Conv2d(128, 128, kernel_size = 1)
- self.final2 = nn.Conv2d(128, nb_classes, kernel_size = 1)
+ self.conv1h = nn.Conv2d(in_channels, 32, kernel_size=(ks // 2 + 1, ks))
+ self.conv2h = nn.Conv2d(32, 64, kernel_size=(ks // 2 + 1, ks))
+ self.conv1v = nn.Conv2d(in_channels, 32, kernel_size=(1, ks // 2 + 1))
+ self.conv2v = nn.Conv2d(32, 64, kernel_size=(1, ks // 2 + 1))
+ self.final1 = nn.Conv2d(128, 128, kernel_size=1)
+ self.final2 = nn.Conv2d(128, nb_classes, kernel_size=1)
def forward(self, x):
xh = F.pad(x, (0, 0, 1, -1))
return x.permute(0, 2, 3, 1).contiguous()
+
######################################################################
+
def positional_tensor(height, width):
index_h = torch.arange(height).view(1, -1)
m_h = (2 ** torch.arange(math.ceil(math.log2(height)))).view(-1, 1)
return torch.cat((i_w, i_h), 1)
+
######################################################################
str_experiment = args.data
if args.positional:
- str_experiment += '-positional'
+ str_experiment += "-positional"
if args.dilation:
- str_experiment += '-dilation'
+ str_experiment += "-dilation"
+
+log_file = open("causalar-" + str_experiment + "-train.log", "w")
-log_file = open('causalar-' + str_experiment + '-train.log', 'w')
def log_string(s):
- s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + ' ' + s
+ s = time.strftime("%Y%m%d-%H:%M:%S", time.localtime()) + " " + s
print(s)
- log_file.write(s + '\n')
+ log_file.write(s + "\n")
log_file.flush()
+
######################################################################
+
def generate_sequences(nb, len):
nb_parts = 2
x = torch.empty(nb, nb_parts).uniform_(-1, 1)
x = x.view(nb, nb_parts, 1).expand(nb, nb_parts, len)
- x = x * torch.linspace(0, len-1, len).view(1, -1) + len
+ x = x * torch.linspace(0, len - 1, len).view(1, -1) + len
for n in range(nb):
- a = torch.randperm(len - 2)[:nb_parts+1].sort()[0]
+ a = torch.randperm(len - 2)[: nb_parts + 1].sort()[0]
a[0] = 0
a[a.size(0) - 1] = len
for k in range(a.size(0) - 1):
- r[n, a[k]:a[k+1]] = x[n, k, :a[k+1]-a[k]]
+ r[n, a[k] : a[k + 1]] = x[n, k, : a[k + 1] - a[k]]
return r.round().long()
+
######################################################################
-if args.data == 'toy1d':
+if args.data == "toy1d":
len = 32
train_input = generate_sequences(50000, len).to(device).unsqueeze(1)
if args.dilation:
- model = NetToy1dWithDilation(nb_classes = 2 * len).to(device)
+ model = NetToy1dWithDilation(nb_classes=2 * len).to(device)
else:
- model = NetToy1d(nb_classes = 2 * len).to(device)
+ model = NetToy1d(nb_classes=2 * len).to(device)
-elif args.data == 'mnist':
- train_set = torchvision.datasets.MNIST('./data/mnist/', train = True, download = True)
+elif args.data == "mnist":
+ train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)
- model = PixelCNN(nb_classes = 256, in_channels = 1).to(device)
+ model = PixelCNN(nb_classes=256, in_channels=1).to(device)
in_channels = train_input.size(1)
if args.positional:
positional_input = positional_tensor(height, width).float().to(device)
in_channels += positional_input.size(1)
- model = PixelCNN(nb_classes = 256, in_channels = in_channels).to(device)
+ model = PixelCNN(nb_classes=256, in_channels=in_channels).to(device)
else:
- raise ValueError('Unknown data ' + args.data)
+ raise ValueError("Unknown data " + args.data)
######################################################################
mean, std = train_input.float().mean(), train_input.float().std()
nb_parameters = sum(t.numel() for t in model.parameters())
-log_string(f'nb_parameters {nb_parameters}')
+log_string(f"nb_parameters {nb_parameters}")
cross_entropy = nn.CrossEntropyLoss().to(device)
-optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
for e in range(args.nb_epochs):
-
nb_batches, acc_loss = 0, 0.0
for sequences in train_input.split(args.batch_size):
- input = (sequences - mean)/std
+ input = (sequences - mean) / std
if args.positional:
input = torch.cat(
- (input, positional_input.expand(input.size(0), -1, -1, -1)),
- 1
+ (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
)
output = model(input)
- loss = cross_entropy(
- output.view(-1, output.size(-1)),
- sequences.view(-1)
- )
+ loss = cross_entropy(output.view(-1, output.size(-1)), sequences.view(-1))
optimizer.zero_grad()
loss.backward()
nb_batches += 1
acc_loss += loss.item()
- log_string(f'{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}')
+ log_string(f"{e} {acc_loss / nb_batches} {math.exp(acc_loss / nb_batches)}")
sys.stdout.flush()
for t in range(flat.size(1)):
input = (generated.float() - mean) / std
if args.positional:
- input = torch.cat((input, positional_input.expand(input.size(0), -1, -1, -1)), 1)
+ input = torch.cat(
+ (input, positional_input.expand(input.size(0), -1, -1, -1)), 1
+ )
output = model(input)
logits = output.view(flat.size() + (-1,))[:, t]
- dist = torch.distributions.categorical.Categorical(logits = logits)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
flat[:, t] = dist.sample()
######################################################################
-if args.data == 'toy1d':
-
- with open('causalar-' + str_experiment + '-train.dat', 'w') as file:
+if args.data == "toy1d":
+ with open("causalar-" + str_experiment + "-train.dat", "w") as file:
for j in range(train_input.size(2)):
- file.write(f'{j}')
+ file.write(f"{j}")
for i in range(min(train_input.size(0), 25)):
- file.write(f' {train_input[i, 0, j]}')
- file.write('\n')
+ file.write(f" {train_input[i, 0, j]}")
+ file.write("\n")
- with open('causalar-' + str_experiment + '-generated.dat', 'w') as file:
+ with open("causalar-" + str_experiment + "-generated.dat", "w") as file:
for j in range(generated.size(2)):
- file.write(f'{j}')
+ file.write(f"{j}")
for i in range(generated.size(0)):
- file.write(f' {generated[i, 0, j]}')
- file.write('\n')
-
-elif args.data == 'mnist':
+ file.write(f" {generated[i, 0, j]}")
+ file.write("\n")
- img_train = 1 - train_input[:generated.size(0)].float() / 255
+elif args.data == "mnist":
+ img_train = 1 - train_input[: generated.size(0)].float() / 255
img_generated = 1 - generated.float() / 255
- save_images(img_train, 'causalar-' + str_experiment + '-train.png', nrow = 12)
- save_images(img_generated, 'causalar-' + str_experiment + '-generated.png', nrow = 12)
+ save_images(img_train, "causalar-" + str_experiment + "-train.png", nrow=12)
+ save_images(img_generated, "causalar-" + str_experiment + "-generated.png", nrow=12)
######################################################################