- self.w_G = randw(nb_heads, caterpillar_height, dim_model, amplitude=1e-5)
- self.b_G = nn.Parameter(
- torch.full(
- (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
- )
- )
+ x = kwargs.get("gate_dropout")
+ if x is None:
+ self.proba_gate_dropout = 0.0
+ else:
+ self.proba_gate_dropout = float(x)
+
+ logger(f"self.proba_gate_dropout {self.proba_gate_dropout}")
+
+ x = kwargs.get("default_bg")
+ if x is None:
+ default_bg = -math.log(caterpillar_height - 1)
+ else:
+ default_bg = float(x)
+
+ logger(f"default_bg {default_bg}")
+
+ ######################################################################
+
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model)
+ self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))