Simplify positional-encoding addition to rely on broadcasting, and reformat QKVAttention constructor arguments.
[mygpt.git] / mygpt.py
index ab16e1e..7c4e06d 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
         k = j%2
-        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)
+        return x + pe # Let broadcasting do its job
 
 ##############################
 
@@ -107,7 +108,8 @@ class MyGPT(nn.Module):
                     nn.LayerNorm(dim_model),
                     QKVAttention(
                         dim_in = dim_model,
-                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
+                        dim_qk = dim_keys,
+                        dim_v = dim_model // nb_heads,
                         nb_heads = nb_heads,
                         causal = True, attention_dropout = dropout
                     ),