def wipe_lookahead_rewards(self, batch):
t = torch.arange(batch.size(1), device=batch.device)[None, :]
u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
lr_mask = (t <= u).long() * (
def wipe_lookahead_rewards(self, batch):
t = torch.arange(batch.size(1), device=batch.device)[None, :]
u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
lr_mask = (t <= u).long() * (