- self.state_len = self.height * self.width
- self.index_lookahead_reward = 0
- self.index_states = 1
- self.index_action = self.state_len + 1
- self.index_reward = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward
+ def wipe_lookahead_rewards(self, batch):
+ # Overwrite the lookahead-reward slots of a token batch up to a random
+ # per-sample position, replacing them with a fixed placeholder code.
+ # `batch` is a 2-D tensor (samples, sequence) — assumed integer token
+ # codes; sequences are laid out as fixed-size records of length
+ # `self.world.it_len` with the lookahead-reward token at offset
+ # `self.world.index_lookahead_reward` within each record (the removed
+ # lines above suggest that offset is 0 — confirm against World).
+ # t: position index 0..T-1, shaped (1, T) for broadcasting.
+ t = torch.arange(batch.size(1), device=batch.device)[None, :]
+ # u: one random cutoff position per sample, shaped (N, 1).
+ u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
+ # lr_mask is 1 exactly where BOTH hold: the position is at or before the
+ # sample's cutoff, and it falls on the lookahead-reward offset of its
+ # record; 0 elsewhere. Broadcasts to (N, T).
+ lr_mask = (t <= u).long() * (
+ t % self.world.it_len == self.world.index_lookahead_reward
+ ).long()
+
+ # Masked positions get lookahead_reward2code(2) — presumably the code
+ # for an "unknown/ignored" lookahead reward; verify against the World
+ # implementation. All other positions pass through unchanged.
+ return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch