nb_heads=2,
nb_blocks=5,
dropout=0.1,
- causal=True,
+ #causal=True,
)
model.eval()
attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+
# attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ]
# for a in attention_matrices: a=a/a.sum(-1,keepdim=True)
return self.x[:, self.first : self.first + self.nb]
def complete(self):
    """Return True when the visible window spans the entire sequence.

    The window is complete iff it starts at position 0 and its length
    ``self.nb`` equals the full sequence length ``self.x.size(1)``.

    Fixed: the earlier version compared against bare ``x.size(1)``
    (an unbound name, raising NameError); the intended attribute is
    ``self.x``.  The stray diff markers from the patch residue are
    resolved here as well.
    """
    # NOTE(review): assumes self.x is a tensor-like object whose
    # dimension 1 is the sequence length — consistent with the
    # `self.x[:, self.first : self.first + self.nb]` slice nearby.
    return self.first == 0 and self.nb == self.x.size(1)
######################################################################