- input = self.test_input[:1]
- result = input.clone()
- s = (result == self.t_prog).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
+ input = self.test_input[:1].clone()
+ last = (input != self.t_nul).max(0).values.nonzero().max() + 3
+ input = input[:, :last].to(self.device)