if amm_generator is None:
self.amm_generator = (
- lambda d: torch.arange(d)[None, None, :, None]
- < torch.arange(d)[None, None, None, :]
+ lambda d: torch.arange(d)[:, None]
+ < torch.arange(d)[None, :]
)
else:
self.amm_generator = amm_generator
if self.causal:
if bs_q.first == 0:
- self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
+ self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)[None, None,:,:]
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb