X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=7c4e06df1259aa4bd372987589ababf5f5afd6b1;hb=f031f7f87b5be907df081395023a9acba8ba9c7c;hp=ab16e1e184bdc513f5cd189112cf78908d76abd0;hpb=12184f604b37f36f07d7dcdd567b1c78f02c74db;p=mygpt.git

diff --git a/mygpt.py b/mygpt.py
index ab16e1e..7c4e06d 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe # Let broadcasting do its job

 ##############################

@@ -107,7 +108,8 @@ class MyGPT(nn.Module):
             nn.LayerNorm(dim_model),
             QKVAttention(
                 dim_in = dim_model,
-                dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                dim_qk = dim_keys,
+                dim_v = dim_model // nb_heads,
                 nb_heads = nb_heads, causal = True,
                 attention_dropout = dropout
             ),
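
Note (not part of the diff): the change in PositionalEncoding.forward drops the explicit [None, :, :] unsqueeze and relies on PyTorch broadcasting to align the (T, D) encoding with the (N, T, D) input; since sin(theta + pi/2) = cos(theta), the k = j%2 term puts sines on even dimensions and cosines on odd ones. Below is a minimal sketch checking that the two forms agree; the shapes and len_max value are made up for illustration, with the class attribute self.len_max replaced by a local constant:

import math
import torch

N, T, D, len_max = 2, 5, 8, 1e5   # illustrative values, not taken from mygpt.py
x = torch.randn(N, T, D)

t = torch.arange(T, dtype=x.dtype)[:, None]   # positions, shape (T, 1)
j = torch.arange(D, dtype=x.dtype)[None, :]   # dimensions, shape (1, D)
k = j % 2                                     # 0 -> sin channel, 1 -> cos channel
pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)  # shape (T, D)

old = x + pe[None, :, :]  # pre-change: unsqueeze to (1, T, D) by hand
new = x + pe              # post-change: broadcasting aligns (T, D) with (N, T, D)
assert torch.equal(old, new)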