- # current iteration, sample the next state
- s = -(torch.rand(result.size(0), device=result.device) < 0.2).long()
+ # current iteration, with a proba that depends with the
+ # sequence index, so that we have diverse examples, sample
+ # the next state
+ s = -(
+ torch.rand(result.size(0), device=result.device)
+ <= torch.linspace(0, 1, result.size(0), device=result.device)
+ ).long()