self.caterpillar_height = caterpillar_height
self.attention_dropout = attention_dropout
- self.proba_gate_dropout = 0.0
+ self.proba_gate_dropout = 0.25
self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
self.b_G = nn.Parameter(
warnings.warn("gate dropout", RuntimeWarning)
epsilon = 0.5
- dropout_start = (
+ dropout_head = (
(
torch.rand(G.size(), device=G.device)
.flatten(2, 3)
.float()
)
- dropout_tail = dropout_start.cumsum(dim=3) - dropout_start
+ dropout_tail = dropout_head.cumsum(dim=3) - dropout_head
dropout_active = (
torch.rand(N, 1, 1, 1, device=G.device) < self.proba_gate_dropout
).long()
- dropout_start *= dropout_active
+ dropout_head *= dropout_active
dropout_tail *= dropout_active
G = (
G
- + dropout_start * (1 - epsilon - G.detach())
+ # + dropout_head * (1 - epsilon - G.detach())
- dropout_tail * G.detach()
)