Update.
[mygptrnn.git] / mygpt.py
index 7c9991f..099847c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -493,14 +493,16 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
-        default_b_G = kwargs.get("default_b_G")
-        if default_b_G is None:
-            default_b_G = -math.log(caterpillar_height - 1)
+        default_bg = kwargs.get("default_bg")
+        if default_bg is None:
+            default_bg = -math.log(caterpillar_height - 1)
+        else:
+            default_bg = float(default_bg)
 
-        logger(f"default_b_G {default_b_G}")
+        logger(f"default_bg {default_bg}")
 
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
         self.w_V = randw(nb_heads, dim_v, dim_model)