-def random_order(result, fixed_len):
- order = torch.rand(result.size(), device=result.device)
- order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
- return order.sort(1).indices
+def generation_order(x, fixed_len):
+ if args.random_regression_order:
+ order = torch.rand(x.size(), device=x.device)
+ order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=x.device)
+ order = order.sort(1).indices
+ else:
+ order = (
+ torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+ )
+ return order