- for t in range(self.nb_time_steps-1):
- v = torch.rand(x.size()) * (x > 0).float()
- u = (v.max(dim=-1,keepdim=True).values == v).long()
- n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2
- x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1))
+ for t in range(self.nb_time_steps - 1):
+ v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
+ u = (v.max(dim=-1, keepdim=True).values == v).long()
+ n = (
+ (u * x)
+ .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
+ .sum(dim=-1, keepdim=True)
+ )
+ m = 1 + ((n - 1) * torch.rand(n.size())).long()
+ x = (
+ x
+ + m * u.roll(shifts=-1, dims=-1)
+ - n * u
+ + (n - m) * u.roll(shifts=1, dims=-1)
+ )