)
if save_attention_image is not None:
- input = self.test_input[:1]
- result = input.clone()
- s = (result == self.t_prog).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
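+ # keep a single test sequence, truncated shortly after its
+ # last non-null token (a margin of three positions)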
+ input = self.test_input[:1].clone()
+ last = (input != self.t_nul).max(0).values.nonzero().max() + 3
+ input = input[:, :last]
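+ # a single forward pass over this prefix, with attention
+ # recording enabled, replaces the autoregressive generation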
with torch.autograd.no_grad():
t = model.training
model.eval()
model.record_attention(True)
- model(BracketedSequence(result))
+ model(BracketedSequence(input))
model.train(t)
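+ # ram holds one attention tensor per layer, heads along dim 1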
ram = model.retrieve_attention()
model.record_attention(False)
- tokens_output = [self.id2token[i.item()] for i in result[0]]
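+ # token strings to label the attention maps; the input tokens
+ # are the output tokens shifted right by one position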
+ tokens_output = [self.id2token[i.item()] for i in input[0]]
tokens_input = ["n/a"] + tokens_output[:-1]
for n_head in range(ram[0].size(1)):
filename = os.path.join(