X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=ab16e1e184bdc513f5cd189112cf78908d76abd0;hb=12184f604b37f36f07d7dcdd567b1c78f02c74db;hp=121ad800fb41a631642ac64a44b2b1acfa9fbac2;hpb=a1a7cb9e680378db521f2a1e2139db0e2db903de;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 121ad80..ab16e1e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -57,7 +57,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_in, dim_v * nb_heads) + self.w_o = randw(dim_v * nb_heads, dim_in) def forward(self, x_q, x_kv = None): if x_kv is None: x_kv = x_q @@ -142,7 +142,7 @@ if __name__ == '__main__': model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 )