projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
mygpt.py
diff --git a/mygpt.py b/mygpt.py
index 7c9991f..099847c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -493,14 +493,16 @@ class Caterpillar(nn.Module):
self.proba_gate_dropout = 0.0
self.proba_gate_dropout = 0.0
- default_b_G = kwargs.get("default_b_G")
- if default_b_G is None:
- default_b_G = -math.log(caterpillar_height - 1)
+ default_bg = kwargs.get("default_bg")
+ if default_bg is None:
+ default_bg = -math.log(caterpillar_height - 1)
+ else:
+ default_bg = float(default_bg)
- logger(f"default_b_G {default_b_G}")
+ logger(f"default_bg {default_bg}")
self.w_G = randw(nb_heads, caterpillar_height, dim_model)
self.w_G = randw(nb_heads, caterpillar_height, dim_model)
- self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
+ self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg))
self.w_K = randw(nb_heads, dim_qk, dim_model)
self.w_V = randw(nb_heads, dim_v, dim_model)
self.w_K = randw(nb_heads, dim_qk, dim_model)
self.w_V = randw(nb_heads, dim_v, dim_model)