warnings.warn("Caterpillar", RuntimeWarning)
- def randw(*d):
- return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+ def randw(*d, amplitude=None):
+ if amplitude is None:
+ amplitude = 1 / math.sqrt(d[-1])
+ return nn.Parameter(amplitude * torch.randn(*d))
self.caterpillar_length = caterpillar_length
self.caterpillar_height = caterpillar_height
self.proba_gate_dropout = 0.0
- self.w_G = randw(nb_heads, caterpillar_height, dim_model)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
self.b_G = nn.Parameter(
torch.full(
(nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
self.w_Q = randw(nb_heads, dim_qk, dim_model)
self.w_O = randw(dim_v * nb_heads, dim_model)
- self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk)
- self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v)
+ self.init_K_rec = randw(
+ caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
+ )
+ self.init_V_rec = randw(
+ caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
+ )
def reset_inner_loss(self):
self.acc_attention = 0
# recurrent state, or not at all.
G = (
- torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
+ torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
).sigmoid()
+ ######################################################################
+ # The "flashbacks"
+
+ if self.training and self.proba_gate_dropout > 0.0:
+ # This is a better implementation of "flashbacks".
+
+ # G is NxHxExT where e is the caterpillar's row.
+
+ warnings.warn("gate dropout", RuntimeWarning)
+ epsilon = 0.5
+
+ dropout_start = (
+ (
+ torch.rand(G.size(), device=G.device)
+ .flatten(2, 3)
+ .sort(dim=2)
+ .indices
+ == 0
+ )
+ .unflatten(2, (CH, t1 - t0))
+ .float()
+ )
+
+ dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+
+ dropout_active = (
+ torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
+ ).long()
+
+ dropout_start *= dropout_active
+ dropout_tail *= dropout_active
+
+ G = (
+ G
+ + dropout_start * (1 - epsilon - G.detach())
+ - dropout_tail * G.detach()
+ )
+
+ ######################################################################
+
+ # We prepare the arguments for the parallel scan
+
# Clip the gating to avoid values greater than 1 when several
# heads hit the same row
G = G / G.sum(1, keepdim=True).clamp(min=1)
- # We prepare the arguments for the parallel scan
-
A = 1 - G.sum(1)
- gated_V = torch.einsum("nhet,nhtd->netd", G, V)
- gated_K = torch.einsum("nhet,nhtd->netd", G, K)
+ gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+ gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
# We start from cached values, which matters in inference
init_rec_V = self.rec_V[:, :, t0 - CL : t0]
init_rec_K = self.rec_K[:, :, t0 - CL : t0]
- ######################################################################
-
- if self.training and self.proba_gate_dropout > 0.0:
- # This is a better implementation of "flashbacks". A is
- # NxExT where e is the caterpillar's row.
-
- warnings.warn("gate dropout", RuntimeWarning)
- epsilon = 0.5
-
#################################################################
# Associative scan