for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
input = input.to(local_device)
- output = model(mygpt.BracketedSequence(input)).x
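+ # per-sample random permutation of the positions (argsort of uniform noise along the time dimension)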
+ sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+ output = model(mygpt.BracketedSequence(input), sigma).x
loss = F.cross_entropy(output.transpose(1, 2), input)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
targets = input
- output = model(mygpt.BracketedSequence(input)).x
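+ # a fresh permutation is drawn for the per-token loss pass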
+ sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+ output = model(mygpt.BracketedSequence(input), sigma).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
# [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
- def forward(self, bs):
+ def forward(self, bs, sigma=None):
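+ # sigma: optional per-sample permutation of the T positions; when provided, the
+ # encoding is built from the permuted positions instead of 0..T-1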
if bs.first == 0:
- t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
- :, None
- ]
- j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
- None, :
- ]
- k = j % 2
- self.pe = torch.sin(
- t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
- )
+ if sigma is None:
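+ # no permutation: standard sinusoidal encoding over positions 0..T-1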
+ t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
+ None, :, None
+ ]
+ j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
+ None, None, :
+ ]
+ k = j % 2
+ self.pe = torch.sin(
+ t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ else:
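+ # t_out: original position of the token predicted at each slot of the permuted sequence
+ # t_in: original position of the token fed as input at that slot (-1 for the first slot, which has no predecessor)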
+ t_out = sigma[:, :, None]
+ t_in = F.pad(t_out, (0, 0, 1, -1), value=-1)
+ j = torch.arange(
+ bs.x.size(2) // 2, dtype=bs.x.dtype, device=bs.x.device
+ )[None, None, :]
+ k = j % 2
+ pe_out = torch.sin(
+ t_out / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
+ pe_in = torch.sin(
+ t_in / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
+ )
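+ # first half of the channels encodes the input position, second half the output position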
+ self.pe = torch.cat([pe_in, pe_out], dim=2)
+
self.cache_y = bs.x.new(bs.x.size())
self.cache_y[:, bs.first : bs.first + bs.nb] = (
- bs.slice() + self.pe[bs.first : bs.first + bs.nb]
+ bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
)
return BracketedSequence(self.cache_y, bs.first, bs.nb)
self.temperature = 1.0
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
+ self.embedding = CacheWrapper(
+ nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)
)
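+ # the positional encoding is kept outside the embedding wrapper so that forward() can pass sigma to it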
+ self.positional_encoding = AddPositionalEncoding(len_max)
+
trunk_blocks = []
for b in range(nb_blocks):
m.bias.zero_()
m.weight.fill_(1.0)
- def forward(self, bs):
+ def forward(self, bs, sigma=None):
+ if sigma is not None:
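+ # reorder the tokens of each sequence according to the permutation sigma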
+ bs.x = bs.x.gather(dim=1, index=sigma)
bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
+ bs = self.positional_encoding(bs, sigma)
bs = self.trunk(bs)
bs = self.readout(bs)
bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+ if sigma is not None:
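+ # scatter the logits back to their original positions, undoing the permutation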
+ bs.x.scatter_(
+ dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone()
+ )
return bs
def encode(self, bs):
def partial_forward(self, bs, start_layer=None, end_layer=None):
if start_layer is None:
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
if end_layer is not None: