- self.w_qw = randw(nb_heads, dim_qk, dim_in)
- self.w_qr = randw(nb_heads, dim_qk, dim_in)
- # self.w_k = randw(nb_heads, dim_qk, dim_in)
- self.w_v = randw(nb_heads, dim_v, dim_in)
- self.w_o = randw(dim_v * nb_heads, dim_in)
+ self.w_qw = randw(nb_heads, dim_qk, dim_model)
+ self.w_qr = randw(nb_heads, dim_qk, dim_model)
+ # self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(dim_v * nb_heads, dim_model)