super().__init__()
self.len_max = len_max
- # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
def forward(self, x):
t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
- k = j % 2
-
+ k = j % 2 # parity of the dimension index; the modulo also works on float tensors
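+ # since sin(x + pi/2) = cos(x) and j - k is the even index 2i, this single sin
+ # gives PE_{t,2i} = sin(t / len_max^{2i/D}) and PE_{t,2i+1} = cos(t / len_max^{2i/D}),
+ # with the (T, 1) x (1, D) broadcast building the whole (T, D) table at once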
pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
-
y = x + pe
-
return y
######################################################################
-def vanilla_attention(q, k, v):
+def attention(q, k, v):
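+ # q is (N, H, T, D), k and v are (N, H, S, D); this is standard scaled
+ # dot-product attention, softmax(q k^T / sqrt(D)) v, computed per head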
a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
a = a.softmax(dim=3)
y = torch.einsum("nhts,nhsd->nhtd", a, v)
- y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
return y
-vanilla_attention = torch.compile(vanilla_attention)
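+# JIT-compile the attention function (torch.compile, available in PyTorch >= 2.0)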
+attention = torch.compile(attention)
-# y = flex_attention(q, k, v, score_mod=noop)
+######################################################################
class MHAttention(nn.Module):
def __init__(
self,
- dim_in,
+ dim_model,
dim_qk,
dim_v,
nb_heads=1,
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
self.attention_dropout = attention_dropout
- self.record_attention = False
-
- self.w_q = randw(nb_heads, dim_qk, dim_in)
- self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(nb_heads, dim_v, dim_in)
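+ # per-head projection weights of shape (nb_heads, dim_out, dim_model),
+ # scaled by 1/sqrt(dim_model) in randw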
+ self.w_q = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(nb_heads, dim_v, dim_model)
def forward(self, x_q, x_kv=None):
if x_kv is None:
q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
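+ # index legend for the einsums: n=batch, h=head, t=query time, s=key/value time, c=model dim, d=head dim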
-
- def noop(score, b, h, q_idx, kv_idx):
- return score
-
- y = vanilla_attention(q, k, v)
- # y = flex_attention(q, k, v, score_mod=noop)
-
+ y = attention(q, k, v)
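+ # w_o projects each head's output back to the model dimension and sums over the heads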
y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
return y
nb_heads,
nb_blocks,
dropout=0.0,
- len_max=1024,
+ len_max=1e5,
):
super().__init__()
nn.Dropout(dropout),
)
- self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
trunk_blocks = []
WithResidual(
nn.LayerNorm((dim_model,)),
MHAttention(
- dim_in=dim_model,
+ dim_model=dim_model,
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
c_quizzes = torch.cat(record_c_quizzes, dim=0)
agreements = torch.cat(record_agreements, dim=0)
- return c_quizzes, agreements
-
-
-def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
- record.append(generate_ae_c_quizzes(models, nb, local_device))
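+ # return the concatenated quizzes and agreement mask, moved to the CPU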
+ return c_quizzes.to("cpu"), agreements.to("cpu")
######################################################################
else:
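+ # the recorded tensors are expected to already be on the CPU
+ # (generate_ae_c_quizzes returns CPU tensors)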
return [
- torch.cat([x[k].to("cpu") for x in records], dim=0)
- for k in range(len(records[0]))
+ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
]