result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result + filler * ar_mask
+ result = (1 - ar_mask) * result + ar_mask * filler
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result + filler * ar_mask
+ result = (1 - ar_mask) * result + ar_mask * filler
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_before {s}")
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
+ correct = (1 - ar_mask) * space + ar_mask * input
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_after {s}")
+ s = "".join([self.id2char[k.item()] for k in correct[n]])
+ log_string(f"correct {s}")
##############################################################
model.train(t)