3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # ./causal-autoregression.py --data=toy1d
9 # ./causal-autoregression.py --data=toy1d --dilation
10 # ./causal-autoregression.py --data=mnist
11 # ./causal-autoregression.py --data=mnist --positional
import argparse, math, sys, time
import torch, torchvision

from torch import nn
from torch.nn import functional as F
19 ######################################################################
def save_images(x, filename, nrow=12):
    """Save at most the first 48 images of the batch ``x`` as one image grid.

    The grid is written with ``torchvision.utils.save_image`` using ``nrow``
    images per row and a padding value of 1.0 between tiles.
    """
    print(f"Writing {filename}")
    nb_shown = min(48, x.size(0))
    first_images = x.narrow(0, 0, nb_shown)
    torchvision.utils.save_image(first_images, filename, nrow=nrow, pad_value=1.0)
29 ######################################################################
# Command-line interface. The help texts of --learning_rate and --dilation
# previously repeated the texts of other options (copy-paste mistakes).
parser = argparse.ArgumentParser(
    description="An implementation of a causal autoregression model",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("--data", type=str, default="toy1d", help="What data")

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

parser.add_argument("--nb_epochs", type=int, default=-1, help="How many epochs")

parser.add_argument("--batch_size", type=int, default=100, help="Batch size")

parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")

parser.add_argument(
    "--positional",
    action="store_true",
    default=False,
    help="Do we provide a positional encoding as input",
)

parser.add_argument(
    "--dilation",
    action="store_true",
    default=False,
    help="Do we use dilated convolutions (toy1d only)",
)
62 ######################################################################
args = parser.parse_args()

# --seed documents that a negative value disables seeding, so only seed
# when the value is non-negative.
if args.seed >= 0:
    torch.manual_seed(args.seed)

# A negative --nb_epochs selects a per-dataset default number of epochs.
if args.nb_epochs < 0:
    if args.data == "toy1d":
        args.nb_epochs = 50  # NOTE(review): default not visible in the reviewed listing — confirm
    elif args.data == "mnist":
        args.nb_epochs = 25  # NOTE(review): default not visible in the reviewed listing — confirm
75 ######################################################################
# Pick the compute device: CUDA when available (with cuDNN autotuning
# enabled), otherwise the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    print("Cuda is available")
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
84 ######################################################################
class NetToy1d(nn.Module):
    """Causal 1d convolutional model predicting a class at every position.

    ``forward`` takes a single-channel signal of shape (N, 1, L) and returns
    logits of shape (N, L, nb_classes). The logits at position t only depend
    on input positions strictly before t.

    Args:
        nb_classes: number of output classes per position.
        ks: kernel size of the four hidden convolutions.
        nc: number of hidden channels.
    """

    def __init__(self, nb_classes, ks=2, nc=32):
        super().__init__()
        # Left-only padding: every hidden convolution keeps the length and
        # never looks at later positions.
        self.pad = (ks - 1, 0)
        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks)
        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks)
        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks)
        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks)
        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)

    def forward(self, x):
        # Shift the signal one step to the right (pad 1 on the left, crop 1
        # on the right) so the prediction for t cannot see x[t] itself.
        h = F.relu(self.conv0(F.pad(x, (1, -1))))
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = F.relu(conv(F.pad(h, self.pad)))
        logits = self.conv5(h)
        # (N, nb_classes, L) -> (N, L, nb_classes)
        return logits.permute(0, 2, 1).contiguous()
class NetToy1dWithDilation(nn.Module):
    """Causal 1d convolutional model built from dilated convolutions.

    ``forward`` takes a single-channel signal of shape (N, 1, L) and returns
    logits of shape (N, L, nb_classes). The logits at position t only depend
    on input positions strictly before t.

    Args:
        nb_classes: number of output classes per position.
        ks: kernel size of the four dilated convolutions.
        nc: number of hidden channels.
    """

    def __init__(self, nb_classes, ks=2, nc=32):
        super().__init__()
        self.conv0 = nn.Conv1d(1, nc, kernel_size=1)
        # Each dilated convolution is padded on the left by (ks - 1) * dilation,
        # which keeps the length and only exposes past positions.
        self.pad1 = ((ks - 1) * 2, 0)
        self.conv1 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=2)
        self.pad2 = ((ks - 1) * 4, 0)
        self.conv2 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=4)
        self.pad3 = ((ks - 1) * 8, 0)
        self.conv3 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=8)
        self.pad4 = ((ks - 1) * 16, 0)
        self.conv4 = nn.Conv1d(nc, nc, kernel_size=ks, dilation=16)
        self.conv5 = nn.Conv1d(nc, nb_classes, kernel_size=1)
        # NOTE(review): all dilations are even, so after the one-step shift in
        # forward() the output at t only sees inputs at odd offsets t-1, t-3, ...
        # A first dilation of 1 would cover every past offset — confirm intent.

    def forward(self, x):
        # Shift one step right so the prediction for t cannot see x[t].
        x = F.relu(self.conv0(F.pad(x, (1, -1))))
        # Each convolution must use the padding matching its own dilation
        # (pad1 for conv1, ...). The previous pad2..pad5 offset over-padded
        # every layer and referenced a non-existent self.pad5.
        x = F.relu(self.conv1(F.pad(x, self.pad1)))
        x = F.relu(self.conv2(F.pad(x, self.pad2)))
        x = F.relu(self.conv3(F.pad(x, self.pad3)))
        x = F.relu(self.conv4(F.pad(x, self.pad4)))
        x = self.conv5(x)
        # (N, nb_classes, L) -> (N, L, nb_classes)
        return x.permute(0, 2, 1).contiguous()
132 ######################################################################
class PixelCNN(nn.Module):
    """Raster-order causal convolutional model for images.

    ``forward`` takes a tensor of shape (N, in_channels, H, W) and returns
    logits of shape (N, H, W, nb_classes). The logits at pixel (i, j) only
    depend on pixels in rows above i, or in row i to the left of j.

    Args:
        nb_classes: number of output classes per pixel.
        in_channels: number of input channels.
        ks: kernel width; the padding arithmetic keeps the image size only
            when ks is odd.
    """

    def __init__(self, nb_classes, in_channels=1, ks=5):
        super().__init__()

        half = ks // 2
        # F.pad order is (left, right, top, bottom).
        # Rows-above stream: pad above and on both sides.
        self.hpad = (half, half, half, 0)
        # Same-row stream: pad on the left only.
        self.vpad = (half, 0, 0, 0)

        self.conv1h = nn.Conv2d(in_channels, 32, kernel_size=(half + 1, ks))
        self.conv2h = nn.Conv2d(32, 64, kernel_size=(half + 1, ks))
        self.conv1v = nn.Conv2d(in_channels, 32, kernel_size=(1, half + 1))
        self.conv2v = nn.Conv2d(32, 64, kernel_size=(1, half + 1))
        self.final1 = nn.Conv2d(128, 128, kernel_size=1)
        self.final2 = nn.Conv2d(128, nb_classes, kernel_size=1)

    def forward(self, x):
        # Shift the image one row down: this stream only sees rows above.
        above = F.pad(x, (0, 0, 1, -1))
        # Shift the image one column right: this stream only sees the pixels
        # to the left in the same row.
        left = F.pad(x, (1, -1, 0, 0))
        for conv_above, conv_left in (
            (self.conv1h, self.conv1v),
            (self.conv2h, self.conv2v),
        ):
            above = F.relu(conv_above(F.pad(above, self.hpad)))
            left = F.relu(conv_left(F.pad(left, self.vpad)))
        merged = F.relu(self.final1(torch.cat((above, left), 1)))
        logits = self.final2(merged)
        # (N, nb_classes, H, W) -> (N, H, W, nb_classes)
        return logits.permute(0, 2, 3, 1).contiguous()
162 ######################################################################
def positional_tensor(height, width):
    """Binary positional encoding of a ``height`` x ``width`` grid.

    Returns an integer tensor of shape
    ``(1, nb_bits(height) + nb_bits(width), height, width)`` with
    ``nb_bits(n) = ceil(log2(n))``. The first ``nb_bits(height)`` channels hold
    the bits of the row index (least significant first). The remaining
    channels hold the bits of the column index.

    The previous version built both maps with the wrong axes and sizes,
    giving (1, nb, H, H) and (1, nb, W, W). ``torch.cat`` therefore failed
    for non-square grids. For square grids the result is unchanged.
    """
    # ceil(log2(n)) == (n - 1).bit_length() for n >= 1; integer-only, so no
    # floating-point rounding can change the number of bits.
    nb_bits_h = (height - 1).bit_length()
    nb_bits_w = (width - 1).bit_length()

    rows = torch.arange(height).view(1, -1)
    row_bits = (rows // (2 ** torch.arange(nb_bits_h)).view(-1, 1)) % 2
    # (nb_bits_h, height) -> (1, nb_bits_h, height, width), constant along columns
    row_bits = row_bits[None, :, :, None].expand(-1, -1, -1, width)

    cols = torch.arange(width).view(1, -1)
    col_bits = (cols // (2 ** torch.arange(nb_bits_w)).view(-1, 1)) % 2
    # (nb_bits_w, width) -> (1, nb_bits_w, height, width), constant along rows
    col_bits = col_bits[None, :, None, :].expand(-1, -1, height, -1)

    return torch.cat((row_bits, col_bits), 1)
179 ######################################################################
# Tag shared by every output file name, e.g. "mnist-positional".
str_experiment = args.data + "".join(
    suffix
    for suffix, enabled in (
        ("-positional", args.positional),
        ("-dilation", args.dilation),
    )
    if enabled
)

log_file = open("causalar-" + str_experiment + "-train.log", "w")
def log_string(s):
    """Prefix ``s`` with a timestamp, print it, and append it to ``log_file``."""
    stamp = time.strftime("%Y%m%d-%H:%M:%S", time.localtime())
    line = " ".join((stamp, s))
    print(line)
    log_file.write(line + "\n")
    # Flush on every message so the log is readable while training runs.
    log_file.flush()
199 ######################################################################
def generate_sequences(nb, len):
    """Return ``nb`` piecewise-linear integer sequences of length ``len``.

    Every sequence is cut at random points into consecutive segments. On
    segment k the value is ``len + slope_k * t``, where ``slope_k`` is drawn
    uniformly in [-1, 1) and ``t`` restarts at 0 at the beginning of the
    segment. The result is rounded and returned as a long tensor of shape
    (nb, len).

    The parameter name shadows the builtin ``len`` inside this function; it
    is kept as-is for compatibility with existing callers.
    """
    nb_parts = 2  # NOTE(review): value not visible in the reviewed listing — confirm

    out = torch.empty(nb, len)

    slopes = torch.empty(nb, nb_parts).uniform_(-1, 1)
    ramp = torch.linspace(0, len - 1, len).view(1, -1)
    lines = slopes.view(nb, nb_parts, 1).expand(nb, nb_parts, len) * ramp + len

    for n in range(nb):
        # Sorted random cut points; force the first to 0 and the last to len
        # so the segments tile the whole sequence.
        cuts = torch.randperm(len - 2)[: nb_parts + 1].sort()[0]
        cuts[0] = 0
        cuts[-1] = len
        bounds = zip(cuts[:-1].tolist(), cuts[1:].tolist())
        for k, (start, stop) in enumerate(bounds):
            out[n, start:stop] = lines[n, k, : stop - start]

    return out.round().long()
221 ######################################################################
if args.data == "toy1d":
    # Sequence length. generate_sequences() yields integers below 2 * seq_len,
    # hence the number of classes. Named seq_len (not len) so that the builtin
    # len() is no longer shadowed for the rest of the module.
    seq_len = 32  # NOTE(review): value not visible in the reviewed listing — confirm
    train_input = generate_sequences(50000, seq_len).to(device).unsqueeze(1)
    if args.dilation:
        model = NetToy1dWithDilation(nb_classes=2 * seq_len).to(device)
    else:
        model = NetToy1d(nb_classes=2 * seq_len).to(device)

elif args.data == "mnist":
    train_set = torchvision.datasets.MNIST("./data/mnist/", train=True, download=True)
    train_input = train_set.data.view(-1, 1, 28, 28).long().to(device)

    in_channels = train_input.size(1)

    if args.positional:
        height, width = train_input.size(2), train_input.size(3)
        positional_input = positional_tensor(height, width).float().to(device)
        in_channels += positional_input.size(1)

    # Build the model once, with its final number of input channels
    # (previously a throw-away PixelCNN with in_channels=1 was built first).
    model = PixelCNN(nb_classes=256, in_channels=in_channels).to(device)

else:
    raise ValueError("Unknown data " + args.data)
248 ######################################################################
# Normalisation statistics over the whole training set (used for training
# and for sampling).
mean = train_input.float().mean()
std = train_input.float().std()

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters}")

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
cross_entropy = nn.CrossEntropyLoss().to(device)
for e in range(args.nb_epochs):
    acc_loss = 0.0
    nb_batches = 0

    for sequences in train_input.split(args.batch_size):
        net_input = (sequences - mean) / std
        if args.positional:
            # Append the positional channels, broadcast over the batch.
            pos = positional_input.expand(net_input.size(0), -1, -1, -1)
            net_input = torch.cat((net_input, pos), 1)

        output = model(net_input)
        # Flatten to (positions, classes) logits against (positions,) targets.
        loss = cross_entropy(output.view(-1, output.size(-1)), sequences.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()
        nb_batches += 1

    mean_loss = acc_loss / nb_batches
    # Epoch, mean cross-entropy, and its exponential (perplexity).
    log_string(f"{e} {mean_loss} {math.exp(mean_loss)}")
284 ######################################################################
# Sample 48 new sequences/images one position at a time, in flattened order.
# `flat` is a view of `generated`, so writing into it fills `generated`.
# Positions >= t are still zero when predicting t; the models are causal, so
# those positions do not influence the logits for t.
generated = train_input.new_zeros((48,) + train_input.size()[1:])
flat = generated.view(generated.size(0), -1)

# Sampling only: skip autograd bookkeeping for the many forward passes.
with torch.no_grad():
    for t in range(flat.size(1)):
        net_input = (generated.float() - mean) / std
        if args.positional:
            net_input = torch.cat(
                (net_input, positional_input.expand(net_input.size(0), -1, -1, -1)),
                1,
            )
        output = model(net_input)
        # Logits predicted for position t of every sample.
        logits = output.view(flat.size() + (-1,))[:, t]
        dist = torch.distributions.categorical.Categorical(logits=logits)
        flat[:, t] = dist.sample()
301 ######################################################################
if args.data == "toy1d":

    def write_columns(filename, sequences):
        # One text line per time step j: "j", then " <value>" for every
        # sequence of the batch (values read as sequences[i, 0, j]).
        with open(filename, "w") as file:
            for j, values in enumerate(sequences[:, 0].t().tolist()):
                file.write(f"{j}" + "".join(f" {v}" for v in values) + "\n")

    write_columns(
        "causalar-" + str_experiment + "-train.dat",
        train_input[: min(train_input.size(0), 25)],
    )
    write_columns("causalar-" + str_experiment + "-generated.dat", generated)

elif args.data == "mnist":
    # Intensities are inverted (1 - x / 255) before saving.
    for images, filename in (
        (train_input[: generated.size(0)], "causalar-" + str_experiment + "-train.png"),
        (generated, "causalar-" + str_experiment + "-generated.png"),
    ):
        save_images(1 - images.float() / 255, filename, nrow=12)
325 ######################################################################