self.proba_gate_dropout = 0.0
- self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model)
self.b_G = nn.Parameter(
torch.full(
(nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
self.w_O = randw(dim_v * nb_heads, dim_model)
self.init_K_rec = randw(
- caterpillar_height, caterpillar_length, dim_qk, amplitude=1e-5
+ caterpillar_height,
+ caterpillar_length,
+ dim_qk,
)
self.init_V_rec = randw(
- caterpillar_height, caterpillar_length, dim_v, amplitude=1e-5
+ caterpillar_height,
+ caterpillar_length,
+ dim_v,
)
def reset_inner_loss(self):
warnings.warn("gate dropout", RuntimeWarning)
epsilon = 0.5
- dropout_start = (
- (
- torch.rand(G.size(), device=G.device)
- .flatten(2, 3)
- .sort(dim=2)
- .indices
- == 0
- )
- .unflatten(2, (CH, t1 - t0))
+ dropout_head = (
+ (torch.rand(N, H, 1, t1 - t0, device=G.device).sort(dim=3).indices == 0)
+ .expand_as(G)
.float()
)
- dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+ dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
dropout_active = (
torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
).long()
- dropout_start *= dropout_active
+ dropout_head *= dropout_active
dropout_tail *= dropout_active
G = (
G
- + dropout_start * (1 - epsilon - G.detach())
+ # + dropout_head * (1 - epsilon - G.detach())
- dropout_tail * G.detach()
)