diff --git a/mygpt.py b/mygpt.py
index 42960b1..121ad80 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -41,34 +41,43 @@ class PositionalEncoding(nn.Module):
 ##############################
 
 class QKVAttention(nn.Module):
-    def __init__(self, dim_in, dim_qk, dim_v,
-                 nb_heads = 1, causal = False, attention_dropout = 0.0):
+    def __init__(
+        self,
+        dim_in, dim_qk, dim_v,
+        nb_heads = 1, causal = False, attention_dropout = 0.0
+    ):
         super().__init__()
 
         def randw(*d):
-            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.causal = causal
+        self.attention_dropout = attention_dropout
 
         self.w_q = randw(nb_heads, dim_qk, dim_in)
         self.w_k = randw(nb_heads, dim_qk, dim_in)
         self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(nb_heads, dim_in, dim_v)
-        self.causal = causal
-        self.attention_dropout = attention_dropout
+        self.w_o = randw(dim_in, dim_v * nb_heads)
 
     def forward(self, x_q, x_kv = None):
         if x_kv is None: x_kv = x_q
+
         q = torch.einsum('ntc,hdc->nhtd', x_q, self.w_q)
         k = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_k)
         v = torch.einsum('ntc,hdc->nhtd', x_kv, self.w_v)
+
         a = torch.einsum('nhtd,nhsd->nhts', q, k) / math.sqrt(q.size(3))
+
         if self.causal:
             mask = torch.arange(a.size(2), device = q.device)[None, None, :, None] \
                    < torch.arange(a.size(3), device = q.device)[None, None, None, :]
             a = a.masked_fill(mask, float('-inf'))
+
         a = a.softmax(dim = 3)
         a = F.dropout(a, self.attention_dropout, self.training)
-        y = torch.einsum('nhts,nhsd->nhtd', a, v)
-        y = torch.einsum('nhtd,hcd->ntc', y, self.w_o)
+        y = torch.einsum('nhts,nhsd->nthd', a, v).flatten(2)
+
+        y = y @ self.w_o
 
         return y
 
@@ -102,7 +111,6 @@ class MyGPT(nn.Module):
                     nb_heads = nb_heads,
                     causal = True, attention_dropout = dropout
                 ),
-                nn.Linear(in_features = dim_model, out_features = dim_model),
             ),
             Residual(
                 nn.LayerNorm(dim_model),
@@ -118,10 +126,11 @@ class MyGPT(nn.Module):
         self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
 
     def forward(self, x):
+        x = F.pad(x, (1, 0))
         x = self.embedding(x)
         x = self.trunk(x)
         x = self.readout(x)
-        return x
+        return x[:, :-1]
 
 ######################################################################
 
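
The first hunk replaces the per-head output maps w_o of shape (nb_heads, dim_in, dim_v) by a single (dim_in, dim_v * nb_heads) matrix applied to the concatenated heads via `y @ self.w_o`, which is why the separate nn.Linear after QKVAttention in the trunk is dropped in the second hunk. Note that this matmul only lines up when nb_heads * dim_v == dim_in, which holds when the caller sets dim_v to dim_model // nb_heads. A minimal smoke test of the revised interface (a sketch, not part of the repo, assuming mygpt.py is importable as `mygpt`):

import torch
from mygpt import QKVAttention

N, T, D = 2, 7, 64                       # batch, sequence length, dim_in
attn = QKVAttention(dim_in = D, dim_qk = 16, dim_v = D // 4,
                    nb_heads = 4, causal = True)

x = torch.randn(N, T, D)
y = attn(x)
print(y.shape)                           # torch.Size([2, 7, 64])

# Causal masking check: perturbing the last position must leave all
# earlier outputs unchanged, since masked attention weights are exactly 0.
x2 = x.clone()
x2[:, -1] += 1.0
assert torch.allclose(attn(x2)[:, :-1], y[:, :-1])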
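
The last hunk makes MyGPT.forward prepend a zero token with F.pad(x, (1, 0)) and drop the final output position, so the logits at position t are a function of tokens strictly before t and the model can be trained with a plain cross-entropy against the unshifted input. A toy illustration of the shift (the training lines in comments are hypothetical, not part of this diff):

import torch
import torch.nn.functional as F

x = torch.tensor([[5, 3, 9, 1]])          # (N = 1, T = 4) token indices
padded = F.pad(x, (1, 0))                  # tensor([[0, 5, 3, 9, 1]])

# With the causal trunk, the readout at padded position t only sees
# padded[:, :t + 1] = [0, x_0, ..., x_{t-1}]; slicing off the last position
# restores length T, so output[:, t] predicts x[:, t]:
#
#     logits = model(x)                    # (N, T, vocabulary_size)
#     loss = F.cross_entropy(logits.transpose(1, 2), x)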