X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=57cbbc6b23d66404e374851dedf4a35fa4c85852;hb=ceda7771b579aa3fb21115c6e71975d3cb7583bd;hp=42960b13ce438d7002c20783676debc3aab5793e;hpb=02a6cbfdda55ba2ce13ff0f925009c15cc2a0a90;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 42960b1..57cbbc6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -51,7 +51,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(nb_heads, dim_in, dim_v) + self.w_o = randw(dim_in, dim_v * nb_heads) self.causal = causal self.attention_dropout = attention_dropout @@ -67,8 +67,8 @@ class QKVAttention(nn.Module): a = a.masked_fill(mask, float('-inf')) a = a.softmax(dim = 3) a = F.dropout(a, self.attention_dropout, self.training) - y = torch.einsum('nhts,nhsd->nhtd', a, v) - y = torch.einsum('nhtd,hcd->ntc', y, self.w_o) + y = torch.einsum('nhts,nhsd->nthd', a, v) + y = y.flatten(2) @ self.w_o return y