X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=311ff6bf4dd39c35f1d9182c8db246f1f375f5f1;hb=2cd3f15987d2bf9050f737cd13506740ad3e90cb;hp=a1db2e312d19ef7aba0d82f024f2618a32f471d7;hpb=e23374a06a07a1bc899e1c7ff7f5d8be75f9cdb5;p=beaver.git

diff --git a/mygpt.py b/mygpt.py
index a1db2e3..311ff6b 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -85,22 +85,45 @@ class AddPositionalEncoding(nn.Module):
 
     # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
 
-    def forward(self, bs):
+    def forward(self, bs, order):  # NxTxD, T
         if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+            t = (
+                torch.arange(bs.x.size(1) + 1, dtype=bs.x.dtype, device=bs.x.device)[
+                    :, None
+                ]
+                - 1
+            )
+            j = torch.arange(bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device)[
                 None, :
             ]
             k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+            pe = (
+                torch.sin(
+                    t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+                )
+                .unsqueeze(0)
+                .expand(bs.x.size(0), -1, -1)
             )
+
+            order_output = order + 1
+            order_input = torch.cat(
+                (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1
+            )
+
+            self.pe = torch.cat(
+                (
+                    pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
+                    pe.gather(
+                        1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))
+                    ),
+                ),
+                2,
+            )
+
         self.cache_y = bs.x.new(bs.x.size())
 
         self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
         )
 
         bs.x = self.cache_y
@@ -201,10 +224,10 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
+        self.embedding = CacheWrapper(
+            nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
         )
+        self.pe = AddPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -246,22 +269,28 @@ class MyGPT(nn.Module):
                 m.bias.zero_()
                 m.weight.fill_(1.0)
 
-    def forward(self, bs, mode='standard'):
-        bs.x = F.pad(bs.x, (1, -1))
+    def forward(self, bs, mode="standard", order=None):
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        if order is None:
+            order = torch.arange(bs.x.size(1), device=bs.x.device)[None, :].expand_as(
+                bs.x
+            )
         bs = self.embedding(bs)
-        if mode=='standard':
+        bs = self.pe(bs, order)
+
+        if mode == "standard":
             bs = self.trunk(bs)
             bs = self.readout(bs)
-        elif mode=='head':
+        elif mode == "head":
             bs = self.trunk(bs)
-        elif mode=='deep':
+        elif mode == "deep":
             r = []
             for l in self.trunk:
                 bs = l(bs)
-                r += [ bs.slice() ]
+                r += [bs.slice()]
             bs = BracketedSequence(torch.cat(r, -1))
         else:
-            raise ValueError
+            raise ValueError(f"{mode=}")
 
         return bs
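
What the reworked AddPositionalEncoding computes, read off the hunk above: the table pe holds the usual sinusoidal code PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) (the + math.pi / 2 * k term turns sin into cos on the odd channels), but over D//2 channels and T+1 rows, with row 0 standing for position -1. For each step, order_input indexes the position of the token the model reads and order_output the position of the token it must predict, and the two gathered halves are concatenated back to D channels. A minimal standalone sketch of that computation; the sizes N, T, D and len_max are toy assumptions, not values from the repository:

import math

import torch

N, T, D = 2, 5, 8  # batch, sequence length, model dimension (toy values)
len_max = 1024

# Sinusoidal table over T+1 positions and D//2 channels, as in the diff;
# row 0 is position -1, the input position of the first step.
t = torch.arange(T + 1, dtype=torch.float32)[:, None] - 1
j = torch.arange(D // 2, dtype=torch.float32)[None, :]
k = j % 2
pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)
pe = pe.unsqueeze(0).expand(N, -1, -1)  # N x (T+1) x D//2

# Identity order: step s reads the token at position s-1 and predicts
# the token at position s.
order = torch.arange(T)[None, :].expand(N, T)
order_output = order + 1
order_input = torch.cat((order.new_zeros(N, 1), order[:, :-1] + 1), 1)

# Look up the read-position and write-position encodings per step and
# concatenate them on the channel dimension, back to D channels.
full_pe = torch.cat(
    (
        pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
        pe.gather(1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1))),
    ),
    2,
)
print(full_pe.size())  # torch.Size([2, 5, 8])

Encoding the read position and the write position separately is what lets forward() be driven with a non-sequential order while the trunk itself stays unchanged.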
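
A hedged usage sketch of the new forward() signature. Only vocabulary_size, dim_model, nb_heads, dropout and len_max appear in the hunks above; dim_keys, dim_hidden and nb_blocks are assumed constructor arguments from the parts of mygpt.py this diff does not show, and all values are arbitrary toy choices:

import torch

from mygpt import MyGPT, BracketedSequence  # assumes mygpt.py is importable

model = MyGPT(
    vocabulary_size=100,
    dim_model=64,
    dim_keys=16,  # assumed argument
    dim_hidden=128,  # assumed argument
    nb_heads=4,
    nb_blocks=2,  # assumed argument
    len_max=256,
)
model.eval()

x = torch.randint(100, (2, 10))  # N=2 sequences of T=10 token indices
bs = BracketedSequence(x)  # assumed to default to the whole sequence

# order=None falls back to the left-to-right identity order.
logits = model(bs, mode="standard").x  # 2 x 10 x 100, readout applied
head = model(bs, mode="head").x  # 2 x 10 x 64, last trunk activations
deep = model(bs, mode="deep").x  # 2 x 10 x (64 * len(model.trunk))

With mode="deep" the trunk is unrolled module by module and every intermediate slice is concatenated on the channel dimension, which exposes all depths of the model in a single pass.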