class QKVAttention(nn.Module):
    def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        causal=False,
+        attention_dropout=0.0,
    ):
        super().__init__()
        self.causal = causal
        self.attention_dropout = attention_dropout
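+        # when True, the forward pass keeps its attention matrix in self.a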
+        self.record_attention = False
        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
            )
        a = a.softmax(dim=3)
+
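+        # keep the attention matrix (computed before dropout) for later retrieval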
+        if self.record_attention:
+            self.a = a
+
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum(
                t_next = dist.sample()
            input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
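+    # enable or disable attention recording in every QKVAttention sub-module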
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
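+    # collect the recorded attention matrices, one per QKVAttention module,
+    # in the order returned by self.modules()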
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
+
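+    # Minimal usage sketch (hypothetical; `model` and `x` are assumed to be a
+    # trained model instance and a token tensor, as in the test block below):
+    #
+    #   model.record_attention(True)
+    #   _ = model(BracketedSequence(x))  # forward pass records attention in each layer
+    #   attention_maps = model.retrieve_attention()  # one tensor per attention layer
+    #   model.record_attention(False)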
######################################################################
        dim_keys=2,
        dim_hidden=2,
        nb_heads=2,
-        nb_blocks=1,
+        nb_blocks=2,
        dropout=0.1,
        causal=True,
    )
    model.eval()
-
    y1 = model(BracketedSequence(x)).x
    y2 = torch.randn_like(y1)
    for s in range(x.size(1)):