- v = (
- (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
- + torch.randint(
- input.size(1) - memex_len, (input.size(0), 1), device=t.device
- )
- ) * memex_mask
- assert v.min() >= 0
- assert v.max() < input.size(1)
- u = u * (1 - memex_mask) + v * memex_mask
-
- new_input = input[n, u]
- assert input.max() < vocabulary_size
- assert new_input.max() < vocabulary_size
- limits = trigger.clone()
- limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
- assert limits.min() == 0
- assert limits.max() == 1
- new_input = new_input * (1 - limits) + marker_token * limits
- assert marker_token < vocabulary_size
- assert new_input.max() < vocabulary_size