- input = self.test_input[:1]
- result = input.clone()
- s = (result == self.t_prog).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
+ ns=torch.randint(self.text_input.size(0),(1,)).item()
+ input = self.test_input[ns:ns+1].clone()
+ last = (input != self.t_nul).max(0).values.nonzero().max() + 3
+ input = input[:, :last].to(self.device)