import torch, math, sys, argparse
-from torch import nn
+from torch import nn, einsum
from torch.nn import functional as F
+import matplotlib.pyplot as plt
+
######################################################################
parser = argparse.ArgumentParser(description='Toy attention model.')
help = 'Provide a positional encoding',
action='store_true', default=False)
+parser.add_argument('--seed',
+ type = int, default = 0,
+ help = 'Random seed (default 0, < 0 is no seeding)')
+
args = parser.parse_args()
+if args.seed >= 0:
+ torch.manual_seed(args.seed)
+
######################################################################
label=''
else:
device = torch.device('cpu')
-torch.manual_seed(1)
-
######################################################################
seq_height_min, seq_height_max = 1.0, 25.0
seq_length = 100
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
- st = torch.arange(seq_length).float()
+ st = torch.arange(seq_length, device = device).float()
st = st[None, :, None]
tr = tr[:, None, :, :]
bx = bx[:, None, :, :]
x = torch.cat((xtr, xbx), 2)
- # u = x.sign()
u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
collisions = (u.sum(2) > 1).max(1).values
# Position / height / width
- tr = torch.empty(nb, 2, 3)
+ tr = torch.empty(nb, 2, 3, device = device)
tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
- bx = torch.empty(nb, 2, 3)
+ bx = torch.empty(nb, 2, 3, device = device)
bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
######################################################################
-import matplotlib.pyplot as plt
-
def save_sequence_images(filename, sequences, tr = None, bx = None):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
delta = -1.
if tr is not None:
- ax.scatter(test_tr[k, :, 0], torch.full((test_tr.size(1),), delta), color = 'black', marker = '^', clip_on=False)
+ ax.scatter(tr[:, 0].cpu(), torch.full((tr.size(0),), delta), color = 'black', marker = '^', clip_on=False)
if bx is not None:
- ax.scatter(test_bx[k, :, 0], torch.full((test_bx.size(1),), delta), color = 'black', marker = 's', clip_on=False)
+ ax.scatter(bx[:, 0].cpu(), torch.full((bx.size(0),), delta), color = 'black', marker = 's', clip_on=False)
fig.savefig(filename, bbox_inches='tight')
class AttentionLayer(nn.Module):
def __init__(self, in_channels, out_channels, key_channels):
- super(AttentionLayer, self).__init__()
+ super().__init__()
self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
- A = Q.permute(0, 2, 1).matmul(K).softmax(2)
- x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
- return x
+ A = einsum('nct,ncs->nts', Q, K).softmax(2)
+ y = einsum('nts,ncs->nct', A, V)
+ return y
def __repr__(self):
return self._get_name() + \
def attention(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
- return Q.permute(0, 2, 1).matmul(K).softmax(2)
+ A = einsum('nct,ncs->nts', Q, K).softmax(2)
+ return A
######################################################################
test_outputs = model((test_input - mu) / std).detach()
if args.with_attention:
- x = model[0:4]((test_input - mu) / std)
- test_A = model[4].attention(x)
+ k = next(k for k, l in enumerate(model) if isinstance(l, AttentionLayer))
+ x = model[0:k]((test_input - mu) / std)
+ test_A = model[k].attention(x)
test_A = test_A.detach().to('cpu')
test_input = test_input.detach().to('cpu')
test_outputs = test_outputs.detach().to('cpu')
test_targets = test_targets.detach().to('cpu')
+test_bx = test_bx.detach().to('cpu')
+test_tr = test_tr.detach().to('cpu')
for k in range(15):
save_sequence_images(