+ if self.training and self.proba_gate_dropout > 0.0:  # gate dropout: only in training mode and only for a positive drop probability
+ # Original note: G is NxHxRxT, with the R axis indexing the caterpillar's rows. NOTE(review): shape is not verifiable from this hunk -- confirm.
+ # Issues a RuntimeWarning with the message "gate dropout" when this branch runs; whether it is displayed depends on the active warning filters.
+ warnings.warn("gate dropout", RuntimeWarning)
+ # kill: float tensor with G's size on G's device; 1.0 where a uniform [0, 1) sample is <= proba_gate_dropout (gate dropped), else 0.0.
+ kill = (
+ torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout
+ ).float()
+ # mask is the complement of kill: 1.0 for gates that are kept, 0.0 for gates that are dropped.
+ mask = 1 - kill
+ # Evaluate `recurrence` (defined outside this hunk) on the element-wise masked gates, with the same V and K.
+ masked_next_V, masked_next_K = recurrence(G * mask, V, K)
+ # x - x.detach() is zero in value (for finite x), so next_V / next_K keep their incoming values; gradients flow only through the masked results.
+ next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / (
+ 1 - self.proba_gate_dropout  # rescales the gradient term by 1 / (1 - p); NOTE(review): p == 1.0 makes this divisor zero
+ )
+ next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / (  # same construction as for next_V
+ 1 - self.proba_gate_dropout  # presumably next_V / next_K were computed earlier from the unmasked gates -- not visible here, confirm
+ )
+