X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=75adbf657b57039d48244297cf9b373916b22e5c;hb=539b475100e792e284d030e2a0b4bdb41c0ff780;hp=311ff6bf4dd39c35f1d9182c8db246f1f375f5f1;hpb=2cd3f15987d2bf9050f737cd13506740ad3e90cb;p=beaver.git diff --git a/mygpt.py b/mygpt.py index 311ff6b..75adbf6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module): ) order_output = order + 1 - order_input = torch.cat( - (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1 - ) + order_input = F.pad(order + 1, (1, -1)) - self.pe = torch.cat( - ( - pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))), - pe.gather( - 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1)) - ), - ), - 2, + pe_input = pe.gather( + 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1)) + ) + pe_output = pe.gather( + 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1)) ) + self.pe = torch.cat((pe_input, pe_output), 2) self.cache_y = bs.x.new(bs.x.size()) self.cache_y[:, bs.first : bs.first + bs.nb] = (