From b6a9cc237cdadac2351814f92c20607d46b0f583 Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Fri, 29 Jul 2022 10:07:59 +0200
Subject: [PATCH] OCDC

---
 mygpt.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mygpt.py b/mygpt.py
index 9da2e68..954f4f0 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -24,14 +24,12 @@ class WithResidual(nn.Module):
 
 ##############################
 
-class PositionalEncoding(nn.Module):
+class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
 
-    # From Vaswani et al 2018
-    # PE_{t,2i} = sin(t/(L^{2i/D}))
-    # PE_{t,2i+1} = cos(t/(L^{2i/D}))
+    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
     def forward(self, x):
         t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
         j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
@@ -96,7 +94,7 @@ class MyGPT(nn.Module):
         self.embedding = nn.Sequential(
             nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
-            PositionalEncoding(len_max),
+            AddPositionalEncoding(len_max),
        )
 
        trunk_blocks = [ ]
-- 
2.39.5
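
For readers skimming the patch, here is a minimal standalone sketch of the renamed module. Only the class name, the constructor, and the first two lines of forward() are visible in the diff context; the remaining body (the k and pe lines and the return) is an assumption reconstructed from the formula quoted in the comment, not necessarily the author's exact code.

import math
import torch
from torch import nn

class AddPositionalEncoding(nn.Module):
    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
    def forward(self, x):
        # x has shape (batch, sequence, feature); t indexes positions, j features
        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
        # ASSUMED from here on: k is 0 for even feature indices, 1 for odd ones
        k = j % 2
        # sin(z + pi/2) == cos(z), so a single sin call with a phase shift of
        # k * pi/2 covers both the even (sin) and odd (cos) cases of the formula
        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
        # add the encoding to the embeddings, matching the module's new name
        return x + pe

A quick sanity check, assuming the sketch above:

x = torch.randn(2, 5, 16)                  # (batch, sequence, feature)
y = AddPositionalEncoding(len_max=1024)(x)
print(y.shape)                             # torch.Size([2, 5, 16])

The rename itself is purely cosmetic (hence the "OCDC" subject): "AddPositionalEncoding" states that the module adds the encoding to its input rather than merely computing it, which matches its use inside the nn.Sequential embedding stack in the second hunk.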