):
result = self.test_input[:100].clone()
t = torch.arange(result.size(1), device=result.device)
- itl = self.height * self.width + 3
+ state_len = self.height * self.width
+ iteration_len = state_len + 3
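+ # each iteration of the sequence holds the grid state (state_len tokens) plus the action, reward, and lookahead-reward tokens (hence the +3)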
def ar():
masked_inplace_autoregression(
device=self.device,
)
- for u in range(itl, result.size(1) - itl + 1, itl):
- print(f"{itl=} {u=} {result.size(1)=}")
+ for u in range(
+ iteration_len, result.size(1) - iteration_len + 1, iteration_len
+ ):
+ # Set the lookahead reward to -1, then sample the next state
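+ # (a lookahead reward value s is encoded as the token s + 1 + escape.first_lookahead_rewards_code)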
result[:, u - 1] = (-1) + 1 + escape.first_lookahead_rewards_code
- ar_mask = (t >= u).long() * (t < u + self.height * self.width).long()
+ ar_mask = (t >= u).long() * (t < u + state_len).long()
ar_mask = ar_mask[None, :]
ar_mask = ar_mask.expand_as(result)
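+ # zero out the masked positions so that ar() regenerates them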
result *= 1 - ar_mask
ar()
+
+ # Set the lookahead reward to +1, then sample the action and reward
result[:, u - 1] = (1) + 1 + escape.first_lookahead_rewards_code
- ar_mask = (t >= self.height * self.width).long() * (
- t < self.height * self.width + 2
- ).long()
+ ar_mask = (t >= u + state_len).long() * (t < u + state_len + 2).long()
ar_mask = ar_mask[None, :]
ar_mask = ar_mask.expand_as(result)
result *= 1 - ar_mask