From 041605103d6529e5c03fc8ffa98a9a81a78842fb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 21 Mar 2023 23:37:47 +0100 Subject: [PATCH] Update --- README.txt | 4 ++++ mygpt.py | 20 ++++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/README.txt b/README.txt index dc13a4f..8265b48 100644 --- a/README.txt +++ b/README.txt @@ -2,3 +2,7 @@ To train the shortest-path solving GPT, and train the one-shot MLP read-out: ./beaver.py --oneshot + +Same, lighter settings (~95% test success instead of ~99%): + + ./beaver.py --nb_train_samples=25000 --nb_test_samples=10000 --nb_epochs=10 --oneshot diff --git a/mygpt.py b/mygpt.py index 0b63ac8..232b604 100755 --- a/mygpt.py +++ b/mygpt.py @@ -85,7 +85,7 @@ class AddPositionalEncoding(nn.Module): # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) - def forward(self, bs): + def forward(self, bs, order=None): if bs.first == 0: t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ :, None @@ -97,6 +97,10 @@ class AddPositionalEncoding(nn.Module): self.pe = torch.sin( t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k ) + + if order is not None: + self.pe = self.pe.gather(1, order.unsqueeze(-1).expand_as(self.pe)) + self.cache_y = bs.x.new(bs.x.size()) self.cache_y[:, bs.first : bs.first + bs.nb] = ( @@ -201,10 +205,10 @@ class MyGPT(nn.Module): assert dim_model % nb_heads == 0 - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - AddPositionalEncoding(len_max), + self.embedding = CacheWrapper( + nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout) ) + self.pe = AddPositionalEncoding(len_max) trunk_blocks = [] @@ -246,9 +250,13 @@ class MyGPT(nn.Module): m.bias.zero_() m.weight.fill_(1.0) - def forward(self, bs, mode="standard"): - bs.x = F.pad(bs.x, (1, -1)) + def forward(self, bs, mode="standard", order=None): + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + if order is not None: + order = F.pad(order + 1, (1, -1)) bs = self.embedding(bs) + bs = self.pe(bs, order) + if mode == "standard": bs = self.trunk(bs) bs = self.readout(bs) -- 2.20.1