def slice(self):
return self.x[:, self.first : self.first + self.nb]
+ def complete(self):
+ return self.first == 0 and self.nb == x.size(1)
+
######################################################################
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
- assert causal, "TODO: Switch off the cache when non-causal!!!"
self.causal = causal
self.attention_dropout = attention_dropout
+ self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
def forward(self, bs_q):
x_q = bs_q.x
+ assert (
+ self.causal or bs_q.complete()
+ ), "Partial evaluation is only possible for causal models"
+
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
t_next = dist.sample()
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+
######################################################################
dim_keys=2,
dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
causal=True,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
y2 = torch.randn_like(y1)
for s in range(x.size(1)):