+def random_order(result, fixed_len):
+ if args.random_regression_order:
+ order = torch.rand(result.size(), device=result.device)
+ order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+ return order.sort(1).indices
+ else:
+ return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1)
+
+
+def shuffle(x, order, reorder=False):
+ if x.dim() == 3:
+ order = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
+ if reorder:
+ y = x.new(x.size())
+ y.scatter_(1, order, x)
+ return y
+ else:
+ return x.gather(1, order)
+
+