X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=099847c95d9404d477b069d8cdf78a62304b3784;hb=2434c00a82ebb0b23f45d891cc9f80324e3200bd;hp=7c9991f7d56fc3069a261b0642e4381c55bd02d9;hpb=73acbc986f9c386c001117581c4fc72d2f36803a;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 7c9991f..099847c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -493,14 +493,16 @@ class Caterpillar(nn.Module): self.proba_gate_dropout = 0.0 - default_b_G = kwargs.get("default_b_G") - if default_b_G is None: - default_b_G = -math.log(caterpillar_height - 1) + default_bg = kwargs.get("default_bg") + if default_bg is None: + default_bg = -math.log(caterpillar_height - 1) + else: + default_bg = float(default_bg) - logger(f"default_b_G {default_b_G}") + logger(f"default_bg {default_bg}") self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G)) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model)