def wipe_lookahead_rewards(self, batch):
    """Overwrite a random prefix of the lookahead-reward slots in *batch*.

    For every row an independent cutoff is drawn uniformly from
    ``[0, batch.size(1))``.  Each position ``p`` that is at or before the
    row's cutoff and satisfies
    ``p % self.it_len == self.index_lookahead_reward`` is replaced by
    ``greed.lookahead_reward2code(2)``; every other entry of *batch* is
    returned unchanged.

    NOTE(review): assumes *batch* is a 2-D integer tensor laid out as
    (rows, positions) -- confirm against the callers.
    """
    n_rows, n_cols = batch.size(0), batch.size(1)

    # Position indices, shaped (1, n_cols) so they broadcast over rows.
    positions = torch.arange(n_cols, device=batch.device)[None, :]
    # One random cutoff per row, shaped (n_rows, 1).
    cutoffs = torch.randint(n_cols, (n_rows, 1), device=batch.device)

    at_or_before_cutoff = positions <= cutoffs
    is_lookahead_slot = positions % self.it_len == self.index_lookahead_reward
    # 1 where the entry must be overwritten, 0 where it is kept.
    wipe = (at_or_before_cutoff & is_lookahead_slot).long()

    # NOTE(review): 2 is presumably the "unknown reward" value -- confirm
    # against the greed module.
    replacement = greed.lookahead_reward2code(2)

    # Arithmetic blend: selects `replacement` where wipe == 1, else `batch`.
    return wipe * replacement + (1 - wipe) * batch
+