self.proba_gate_dropout = 0.0
- default_b_G = kwargs.get("default_b_G")
- if default_b_G is None:
- default_b_G = -math.log(caterpillar_height - 1)
+ default_bg = kwargs.get("default_bg")
+ if default_bg is None:
+ default_bg = -math.log(caterpillar_height - 1)
+ else:
+ default_bg = float(default_bg)
- logger(f"default_b_G {default_b_G}")
+ logger(f"default_bg {default_bg}")
self.w_G = randw(nb_heads, caterpillar_height, dim_model)
- self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
+ self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
self.w_K = randw(nb_heads, dim_qk, dim_model)
self.w_V = randw(nb_heads, dim_v, dim_model)