##############################
+class NaNChecker(nn.Module):
+    def __init__(self, name):
+        super().__init__()
+        self.name = name
+
+    def forward(self, bs):
+        x = bs.x if type(bs) is BracketedSequence else bs
+        assert not x.isnan().any(), f"{self.name} detected NaN"
+        assert not x.isinf().any(), f"{self.name} detected Inf"
+        return bs
+
+
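+# Possible usage (illustrative, not from this patch): nn.Sequential(block, NaNChecker("post-attention")),
+# so the assertion message names the layer whose output contained NaN/Inf values.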
class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()
        self.w_qw = randw(nb_heads, dim_qk, dim_model)
        self.w_qr = randw(nb_heads, dim_qk, dim_model)
-        # self.w_k = randw(nb_heads, dim_qk, dim_model)
        self.w_v = randw(nb_heads, dim_v, dim_model)
        self.w_o = randw(dim_v * nb_heads, dim_model)
-    def reset_inner_loss(self):
-        self.acc_attention = 0
-        self.acc_nb = 0
-
-    def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        # return torch.tensor([0], device=self.w_qw.device)
-
    def forward(self, bs):
        x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
        self.rec_v = x_q.new_zeros(
            x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
        )
-        # self.rec_k = x_q.new_zeros(
-        #     x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
-        # )
        self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
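        # rec_v holds one recurrent value vector per memory line and time step;
        # cache_y collects the dim_model-sized output of the layer for every position.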
-        ######################################################################
-        # Prepare the keys
-
-        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-
-        warnings.warn("rotating key barrel", RuntimeWarning)
-        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
-        t_barrel = torch.arange(t0, t1, device=k_star.device)
-        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
-        l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
-        ) % k_star.size(0)
-        k_star = k_star[l_barrel, t_barrel]
-
        ######################################################################
        # Compute the recurrent state
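        # Index letters used in the einsums below: n=batch, h=head, t=time step
        # within the current bracket, l=memory line, c=model feature, d=key/value feature.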
        qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
        v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
-        # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
-        aw = torch.einsum(
-            "nhtd,ltd->nhlt",
-            qw,
-            k_star,
-        ) / math.sqrt(self.w_qw.size(1))
+        aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
+            self.w_qw.size(1)
+        )
        aw = aw.softmax(dim=2)  # nhlt
-        if self.train:
-            self.acc_attention += aw.sum(dim=(0, 1, 3))
-            self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
-
        aw = F.dropout(aw, self.attention_dropout, self.training)
        A = 1 - aw.sum(dim=1)  # nlt
        V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
-        # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
        if t0 == 0:
            V0 = None
-            # K0 = None
        else:
            V0 = self.rec_v[:, :, t0 - 1]
-            # K0 = self.rec_k[:, :, t0 - 1]
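        # pscan_shape is expected to evaluate, in parallel over t, the linear recurrence
        #   rec_v[:, l, t] = A[:, l, t] * rec_v[:, l, t - 1] + V[:, l, t]
        # seeded with V0 when the bracket does not start at t = 0.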
        self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
-        # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
        ######################################################################
        # compute the readout
        ar = torch.einsum(
            "nhtd,ld->nhlt",
            qr,
-            # self.rec_k[:, :, t0:t1],
            self.k_star,
        ) / math.sqrt(self.w_qr.size(1))
        self.acc_nb = 0

    def get_inner_loss(self):
-        warnings.warn("l2 regularization", RuntimeWarning)
-        return (self.acc_attention / self.acc_nb).pow(2).sum()
-        # return torch.tensor([0], device=self.w_qw.device)
+        # warnings.warn("l2 regularization", RuntimeWarning)
+        # return (self.acc_attention / self.acc_nb).pow(2).sum()
+        return torch.tensor([0], device=self.w_qw.device)
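+        # The auxiliary regularization is disabled; a zero tensor is returned instead.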
# warnings.warn("side regularization", RuntimeWarning)
# return (
# (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
        k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-        warnings.warn("rotating key barrel", RuntimeWarning)
+        # warnings.warn("rotating key barrel", RuntimeWarning)
        k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
        t_barrel = torch.arange(t0, t1, device=k_star.device)
        t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
        l_barrel = (
-            torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+            torch.arange(k_star.size(0), device=k_star.device)[:, None]  # + t_barrel
        ) % k_star.size(0)
        k_star = k_star[l_barrel, t_barrel]
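        # With the "+ t_barrel" term commented out, l_barrel is just arange(nb_lines)
        # broadcast over the time axis, so every time step reads the same static key
        # per line instead of rotating through the key barrel.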
    ):
        super().__init__()
+        self.vocabulary_size = vocabulary_size
+
        assert attention_layer in {
            "mha",
            "dumbrec",