- # Put the lookahead reward to either 0 or -1 for the
- # current iteration, sample the next state
- s = -(torch.rand(result.size(0), device=result.device) < 0.2).long()
- result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code
- ar_mask = (t >= u).long() * (t < u + state_len).long()
+ # Generate the next state but keep the initial one; the
+ # lookahead_reward of previous iterations is set to
+ # UNKNOWN
+ if u > 0:
+     result[
+         :, u + self.world.index_lookahead_reward
+     ] = self.world.lookahead_reward2code(2)
+     # Restrict generation to the state tokens of iteration u
+     ar_mask = (t >= u + self.world.index_states).long() * (
+         t < u + self.world.index_states + self.world.state_len
+     ).long()
+     ar(result, ar_mask)
+
+ # Generate the action and reward with lookahead_reward set to +1
+ result[
+     :, u + self.world.index_lookahead_reward
+ ] = self.world.lookahead_reward2code(1)
+ # Mask covering the reward and action tokens of iteration u (inclusive bounds)
+ ar_mask = (t >= u + self.world.index_reward).long() * (
+     t <= u + self.world.index_action
+ ).long()
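
For context on how these masks behave, here is a small self-contained sketch. The layout offsets below (`index_lookahead_reward`, `index_reward`, `index_action`, `index_states`, `state_len`) are made-up placeholder values, not the ones defined by `self.world`; they only illustrate that each `ar_mask` is 1 over a contiguous span of token positions within iteration `u` and 0 elsewhere.

```python
import torch

# Hypothetical per-iteration token layout, for illustration only;
# in the patch above these offsets come from self.world.
index_lookahead_reward = 0  # 1 token: lookahead reward
index_reward = 1            # 1 token: observed reward
index_action = 2            # 1 token: action
index_states = 3            # state_len tokens: grid state
state_len = 4
it_len = index_states + state_len  # tokens per iteration

seq_len = 2 * it_len
t = torch.arange(seq_len)[None, :]  # position indices, shape (1, seq_len)

u = it_len  # start of the second iteration

# Mask over the state tokens of iteration u, as in the patch
state_mask = (t >= u + index_states).long() * (
    t < u + index_states + state_len
).long()

# Mask over the reward..action tokens of iteration u (inclusive bounds)
ra_mask = (t >= u + index_reward).long() * (t <= u + index_action).long()

print(state_mask)  # ones exactly over positions u+3 .. u+6
print(ra_mask)     # ones exactly over positions u+1 .. u+2
```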