# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import torch, math, sys, argparse

from torch import nn, einsum
from torch.nn import functional as F

import matplotlib.pyplot as plt
######################################################################

parser = argparse.ArgumentParser(description="Toy attention model.")

parser.add_argument("--nb_epochs", type=int, default=250)

parser.add_argument(
    "--with_attention",
    help="Use the model with an attention layer",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--group_by_locations",
    help="Use the task where the grouping is location-based",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--positional_encoding",
    help="Provide a positional encoding",
    action="store_true",
    default=False,
)

parser.add_argument(
    "--seed", type=int, default=0, help="Random seed (default 0, < 0 is no seeding)"
)

args = parser.parse_args()

if args.seed >= 0:
    torch.manual_seed(args.seed)
######################################################################

# Tag output files according to the configuration (tag strings assumed;
# the originals are elided from this excerpt)
label = ""

if args.with_attention:
    label = "wa_" + label

if args.group_by_locations:
    label = "gl_" + label

if args.positional_encoding:
    label = "pe_" + label

log_file = open(f"att1d_{label}train.log", "w")

######################################################################


def log_string(s):
    if log_file is not None:
        log_file.write(s + "\n")
        log_file.flush()
    print(s)
    sys.stdout.flush()
######################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")
######################################################################

seq_height_min, seq_height_max = 1.0, 25.0
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100  # length of the generated sequences (value assumed; elided here)
def positions_to_sequences(tr=None, bx=None, noise_level=0.3):
    st = torch.arange(seq_length, device=device).float()
    st = st[None, :, None]
    tr = tr[:, None, :, :]
    bx = bx[:, None, :, :]

    # Triangles: height tr[..., 1] at center tr[..., 0], decreasing linearly
    # to zero over the width tr[..., 2]
    xtr = torch.relu(
        tr[..., 1]
        - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2]
    )

    # Boxes: constant height bx[..., 1] over the width bx[..., 2]
    xbx = torch.sign(
        torch.relu(
            bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2])
        )
    ) * bx[..., 1]

    x = torch.cat((xtr, xbx), 2)

    # Max-pool over adjacent positions so shapes touching in neighboring
    # cells also count as collisions
    u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size=2, stride=1).permute(
        0, 2, 1
    )

    collisions = (u.sum(2) > 1).max(1).values
    y = x.max(2).values

    return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
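
# Shapes, for reference (as implied by the code above): tr and bx are
# (N, 2, 3) tensors of (position, height, width) triples for the two
# triangles and two boxes of each sequence; the function returns a noisy
# (N, seq_length) signal and a (N,) boolean flag marking sequences whose
# shapes overlap (those get rejected in generate_sequences below).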
######################################################################
def generate_sequences(nb):
    # Position / height / width of two triangles and two boxes per sequence
    tr = torch.empty(nb, 2, 3, device=device)
    tr[:, :, 0].uniform_(seq_width_max / 2, seq_length - seq_width_max / 2)
    tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
    tr[:, :, 2].uniform_(seq_width_min, seq_width_max)

    bx = torch.empty(nb, 2, 3, device=device)
    bx[:, :, 0].uniform_(seq_width_max / 2, seq_length - seq_width_max / 2)
    bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
    bx[:, :, 2].uniform_(seq_width_min, seq_width_max)

    if args.group_by_locations:
        # Valid if the mean heights of the left and right groups differ enough
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
        valid = (h_left - h_right).abs() > 4
    else:
        # Valid if the two triangle heights and the two box heights differ enough
        valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (
            torch.abs(bx[:, 0, 1] - bx[:, 1, 1]) > 4
        )

    input, collisions = positions_to_sequences(tr, bx)
    if args.group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1, keepdim=True) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim=True) / 2
        a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
        tr, bx = a.split(2, 1)
    else:
        tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim=True)
        bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim=True)

    targets, _ = positions_to_sequences(tr, bx)

    valid = valid & ~collisions

    tr = tr[valid]
    bx = bx[valid]
    input = input[valid][:, None, :]
    targets = targets[valid][:, None, :]

    if input.size(0) < nb:
        # Rejection sampling: top up recursively until we have nb valid sequences
        input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
        input = torch.cat((input, input2), 0)
        targets = torch.cat((targets, targets2), 0)
        tr = torch.cat((tr, tr2), 0)
        bx = torch.cat((bx, bx2), 0)

    return input, targets, tr, bx
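
# Illustrative usage (shapes as implied above, not part of the original):
#   input, targets, tr, bx = generate_sequences(4)
#   input.shape, targets.shape  -> (4, 1, seq_length) each
#   tr.shape, bx.shape          -> (4, 2, 3) each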
######################################################################
def save_sequence_images(filename, sequences, tr=None, bx=None):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, seq_length)
    ax.set_ylim(-1, seq_height_max + 4)

    for u in sequences:
        ax.plot(torch.arange(u[0].size(0)) + 0.5, u[0], color=u[1], label=u[2])

    ax.legend(frameon=False, loc="upper left")

    # Mark shape positions below the curves (marker styles assumed; the
    # original's exact scatter options are elided from this excerpt)
    delta = -1
    if tr is not None:
        ax.scatter(tr[:, 0], torch.full((tr.size(0),), delta), color="black", marker="^", clip_on=False)
    if bx is not None:
        ax.scatter(bx[:, 0], torch.full((bx.size(0),), delta), color="black", marker="s", clip_on=False)

    fig.savefig(filename, bbox_inches="tight")
    plt.close("all")
######################################################################
class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels):
        super().__init__()
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        V = self.conv_V(x)
        A = einsum("nct,ncs->nts", Q, K).softmax(2)
        y = einsum("nts,ncs->nct", A, V)
        return y

    def __repr__(self):
        return self._get_name() + "(in_channels={}, out_channels={}, key_channels={})".format(
            self.conv_Q.in_channels,
            self.conv_V.out_channels,
            self.conv_K.out_channels,
        )

    def attention(self, x):
        # Same computation as forward, but returns the attention matrix itself
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        A = einsum("nct,ncs->nts", Q, K).softmax(2)
        return A
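
# Minimal shape check (illustrative, not part of the original script): for
# an input of shape (N, C, T), the attention matrix is (N, T, T) and the
# output is (N, out_channels, T).
_layer = AttentionLayer(in_channels=8, out_channels=8, key_channels=4)
_x = torch.randn(2, 8, 16)
assert _layer(_x).shape == (2, 8, 16)
assert _layer.attention(_x).shape == (2, 16, 16)
del _layer, _x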
######################################################################
train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
test_input, test_targets, test_tr, test_bx = generate_sequences(1000)

######################################################################
# Model hyperparameters (values assumed; elided from this excerpt)
nc = 64  # channels in the convolutional trunk
ks = 5  # convolution kernel size

if args.positional_encoding:
    # Binary positional encoding: channel i holds bit i of the position index
    c = math.ceil(math.log(seq_length) / math.log(2.0))
    positional_input = (
        torch.arange(seq_length).unsqueeze(0) // 2 ** torch.arange(c).unsqueeze(1)
    ) % 2
    positional_input = positional_input.unsqueeze(0).float()
else:
    positional_input = torch.zeros(1, 0, seq_length)

in_channels = 1 + positional_input.size(1)
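
# For intuition (hypothetical values, not in the original): with
# seq_length = 8 the encoding would be the 3 channels
#   [[0, 1, 0, 1, 0, 1, 0, 1],
#    [0, 0, 1, 1, 0, 0, 1, 1],
#    [0, 0, 0, 0, 1, 1, 1, 1]]
# which lets the otherwise translation-invariant convolutions read off
# absolute positions.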
if args.with_attention:
    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        AttentionLayer(nc, nc, nc),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size=ks, padding=ks // 2),
    )
else:
    model = nn.Sequential(
        nn.Conv1d(in_channels, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, nc, kernel_size=ks, padding=ks // 2),
        nn.ReLU(),
        nn.Conv1d(nc, 1, kernel_size=ks, padding=ks // 2),
    )
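
# Note: a stack of 5 convolutions with kernel size ks has a receptive field
# of only 5 * (ks - 1) + 1 positions, so the attention variant is the only
# one that can move information across the whole sequence.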
nb_parameters = sum(p.numel() for p in model.parameters())

with open(f"att1d_{label}model.log", "w") as f:
    f.write(str(model) + "\n\n")
    f.write(f"nb_parameters {nb_parameters}\n")
######################################################################

batch_size = 100  # value assumed; elided from this excerpt

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()

model.to(device)

train_input, train_targets = train_input.to(device), train_targets.to(device)
test_input, test_targets = test_input.to(device), test_targets.to(device)
positional_input = positional_input.to(device)

mu, std = train_input.mean(), train_input.std()
for e in range(args.nb_epochs):
    acc_loss = 0.0

    for input, targets in zip(
        train_input.split(batch_size), train_targets.split(batch_size)
    ):
        input = torch.cat((input, positional_input.expand(input.size(0), -1, -1)), 1)

        output = model((input - mu) / std)
        loss = mse_loss(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()

    log_string(f"{e+1} {acc_loss}")
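
# Each line of the training log is the epoch index followed by the MSE loss
# accumulated over that epoch's minibatches.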
######################################################################

train_input = train_input.detach().to("cpu")
train_targets = train_targets.detach().to("cpu")

# Save a few training examples (the number saved is assumed)
for k in range(15):
    save_sequence_images(
        f"att1d_{label}train_{k:03d}.pdf",
        [
            (train_input[k, 0], "blue", "Input"),
            (train_targets[k, 0], "red", "Target"),
        ],
    )
test_input = torch.cat(
    (test_input, positional_input.expand(test_input.size(0), -1, -1)), 1
)
test_outputs = model((test_input - mu) / std).detach()

if args.with_attention:
    # Re-run the layers up to the attention layer and record its attention matrix
    k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
    x = model[0:k]((test_input - mu) / std)
    test_A = model[k].attention(x)
    test_A = test_A.detach().to("cpu")
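    # test_A[k] is the (seq_length, seq_length) attention matrix of test
    # sequence k: row t is the softmax distribution over source positions
    # used to produce output position t.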
test_input = test_input.detach().to("cpu")
test_outputs = test_outputs.detach().to("cpu")
test_targets = test_targets.detach().to("cpu")
test_bx = test_bx.detach().to("cpu")
test_tr = test_tr.detach().to("cpu")
for k in range(15):
    save_sequence_images(
        f"att1d_{label}test_Y_{k:03d}.pdf",
        [
            (test_input[k, 0], "blue", "Input"),
            (test_outputs[k, 0], "orange", "Output"),
        ],
    )

    # Same plot, with the shape positions marked
    save_sequence_images(
        f"att1d_{label}test_Yp_{k:03d}.pdf",
        [
            (test_input[k, 0], "blue", "Input"),
            (test_outputs[k, 0], "orange", "Output"),
        ],
        test_tr[k],
        test_bx[k],
    )
if args.with_attention:
    for k in range(15):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlim(0, seq_length)
        ax.set_ylim(0, seq_length)

        ax.imshow(test_A[k], cmap="binary", interpolation="nearest")

        # Mark box and triangle positions along both axes of the attention
        # map (marker styles assumed; the original's options are elided)
        delta = 0.0
        ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color="black", marker="s", clip_on=False)
        ax.scatter(torch.full((test_bx.size(1),), delta), test_bx[k, :, 0], color="black", marker="s", clip_on=False)
        ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color="black", marker="^", clip_on=False)
        ax.scatter(torch.full((test_tr.size(1),), delta), test_tr[k, :, 0], color="black", marker="^", clip_on=False)

        fig.savefig(f"att1d_{label}test_A_{k:03d}.pdf", bbox_inches="tight")
        plt.close("all")
######################################################################