######################################################################
-def generation_order(x, prompt_len=0):
- if args.random_regression_order:
- order = torch.rand(x.size(), device=x.device)
- order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
- order = order.sort(1).indices
- else:
- order = (
- torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
- )
- return order
-
-
def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT'
u = x.reshape(x.size()[:2] + (-1,))
order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
def shuffle(x, prompt_len):
- order = generation_order(x, prompt_len)
+ if args.random_regression_order:
+ order = torch.rand(x.size(), device=x.device)
+ order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
+ order = order.sort(1).indices
+ else:
+ order = (
+ torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+ )
return reorder(x, order), order