From: Francois Fleuret Date: Fri, 29 Jul 2022 08:07:59 +0000 (+0200) Subject: OCDC X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b6a9cc237cdadac2351814f92c20607d46b0f583;p=mygpt.git OCDC --- diff --git a/mygpt.py b/mygpt.py index 9da2e68..954f4f0 100755 --- a/mygpt.py +++ b/mygpt.py @@ -24,14 +24,12 @@ class WithResidual(nn.Module): ############################## -class PositionalEncoding(nn.Module): +class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() self.len_max = len_max - # From Vaswani et al 2018 - # PE_{t,2i} = sin(t/(L^{2i/D})) - # PE_{t,2i+1} = cos(t/(L^{2i/D})) + # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) def forward(self, x): t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] @@ -96,7 +94,7 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max), + AddPositionalEncoding(len_max), ) trunk_blocks = [ ]