X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;fp=mygpt.py;h=7c4e06df1259aa4bd372987589ababf5f5afd6b1;hb=cd1cc80f711ca1f7188cc9854f18231e02470eba;hp=43711b3da39db0c7709b0613df0e3bd6a5f5153c;hpb=1ad9ea3cca4489b07bad8521966382f66a493eea;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 43711b3..7c4e06d 100755 --- a/mygpt.py +++ b/mygpt.py @@ -36,7 +36,8 @@ class PositionalEncoding(nn.Module): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] k = j%2 - return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :] + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) + return x + pe # Let broadcasting do its job ##############################