models = []
+
+def compute_causal_attzero(t_q, t_k):
+    return t_q < t_k
+
+
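The new helper fixes the masking convention: compute_attzero receives broadcastable tensors of query and key time indices and returns True wherever the attention weight must be zeroed, so t_q < t_k forbids attending to future keys, i.e. the standard causal mask. A minimal standalone check of the convention (a sketch, not part of the patch):

    import torch

    t = torch.arange(5)
    mask = compute_causal_attzero(t[:, None], t[None, :])  # (5, 5) boolean
    # True strictly above the diagonal: future keys are masked out
    assert mask.equal(torch.ones(5, 5).triu(diagonal=1).bool())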
for k in range(args.nb_gpts):
log_string(f"creating model {k} and its w_quizzes")
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
        dropout=args.dropout,
    ).to(main_device)
        dim_hidden=args.dim_hidden,
        nb_heads=args.nb_heads,
        nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
        dropout=args.dropout,
    ).to(main_device)
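Both model-creation sites in main.py now opt in explicitly; the old boolean maps onto the new argument as causal=True -> compute_attzero=compute_causal_attzero and causal=False -> compute_attzero=None, where None means no mask at all (full bidirectional attention).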
        dim_qk,
        dim_v,
        nb_heads=1,
-        causal=False,
+        compute_attzero=None,
        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

-        self.causal = causal
+        self.compute_attzero = compute_attzero
        self.attention_dropout = attention_dropout
        self.record_attention = False
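QKVAttention.__init__ now only stores the callable; the mask itself is built lazily in forward(), on the first chunk of a sequence, so construction needs no knowledge of the sequence length.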
    def forward(self, bs_q):
        x_q = bs_q.x
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
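The dropped assertion tied partial (chunked) evaluation to causal models; the self.causal attribute it tested no longer exists, and the check is removed rather than generalized to arbitrary masks.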
        if bs_q.first == 0:
            self.cache_k = x_q.new_zeros(
                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
-        if self.causal:
+        if self.compute_attzero is not None:
            if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
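The mask is now obtained by calling the stored callable on a column vector and a row vector of positions, which broadcast to a (T, T) boolean matrix; the trailing [None, None, :, :] adds the batch and head dimensions that the old inline expression produced directly, and masked_fill then slices the cached mask down to the current query rows and the keys computed so far. A quick check that the two constructions agree in the causal case (illustrative sketch, T arbitrary):

    import torch

    T = 7
    old = (
        torch.arange(T)[None, None, :, None]
        < torch.arange(T)[None, None, None, :]
    )
    new = compute_causal_attzero(
        torch.arange(T)[:, None], torch.arange(T)[None, :]
    )[None, None, :, :]
    assert old.equal(new)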
        dim_hidden,
        nb_heads,
        nb_blocks,
-        causal=False,
+        compute_attzero=None,
        autoencoder_dim=-1,
        dropout=0.0,
        len_max=1e5,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                        attention_dropout=dropout,
                    ),
                ),
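MyGPT itself only threads the callable down into every QKVAttention block. The payoff of the refactor is that masks other than causal become pluggable without touching the attention code, e.g. a sliding-window mask (hypothetical illustration, not in the patch):

    def compute_window_attzero(t_q, t_k, width=8):
        # zero attention to future keys and to keys more than width steps back
        return (t_q < t_k) | (t_q - t_k >= width)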
        nb_heads=2,
        nb_blocks=2,
        dropout=0.1,
-        causal=True,
    )

    model.eval()
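Note that the self-test now runs without any mask: with the new default compute_attzero=None it exercises full bidirectional attention, where it previously ran causally. Keeping the old behavior would require passing compute_attzero=compute_causal_attzero here (assuming the helper is visible in this file).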