X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=mygpt.py;h=ab16e1e184bdc513f5cd189112cf78908d76abd0;hb=dfeb9072208095669528fc5ae2dedf78f089d9ad;hp=7f0c9e6b7bdd89a77de30ad5cfa9f47d1c8b6257;hpb=fc570d4ccd5d5dee36271d34ff5c672a50a82101;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 7f0c9e6..ab16e1e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -57,7 +57,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_in, dim_v * nb_heads) + self.w_o = randw(dim_v * nb_heads, dim_in) def forward(self, x_q, x_kv = None): if x_kv is None: x_kv = x_q @@ -111,7 +111,6 @@ class MyGPT(nn.Module): nb_heads = nb_heads, causal = True, attention_dropout = dropout ), - nn.Linear(in_features = dim_model, out_features = dim_model), ), Residual( nn.LayerNorm(dim_model), @@ -127,7 +126,7 @@ class MyGPT(nn.Module): self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) def forward(self, x): - x = torch.cat((x.new_zeros(x.size(0), 1), x), 1) + x = F.pad(x, (1, 0)) x = self.embedding(x) x = self.trunk(x) x = self.readout(x) @@ -143,7 +142,7 @@ if __name__ == '__main__': model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 )