######################################################################
- def sigma_for_grids(self, input):
+ def sigma_for_grids(self, input, block_order=(0, 1, 2, 3)):
l = input.size(1) // 4
sigma = input.new(input.size())
r = sigma.view(sigma.size(0), 4, l)
- r[:, 0] = 0 * l
- r[:, 1] = 1 * l
- r[:, 2] = 2 * l
- r[:, 3] = 3 * l
+ r[:, 0, :] = block_order[0] * l
+ r[:, 1, :] = block_order[1] * l
+ r[:, 2, :] = block_order[2] * l
+ r[:, 3, :] = block_order[3] * l
r[:, :, 1:] += (
torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices
) + 1