- # Re-generate the lookahead_reward pessimistically in the
- # previous iterations
- ar_mask = (t < u).long() * (t % it_len == index_lookahead_reward).long()
- ar(result, ar_mask, logit_biases=-optimistic_bias)
- snapshots.append(result[:10].detach().clone())
-
- # Generate the state
- ar_mask = (t >= u).long() * (t < u + state_len).long()
+ # Generate the lookahead_reward and state
+ ar_mask = (t >= u + index_lookahead_reward).long() * (
+ t < u + index_states + state_len
+ ).long()