-def shuffle(x, order, reorder=False):
- if x.dim() == 3:
- order = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
- if reorder:
- y = x.new(x.size())
- y.scatter_(1, order, x)
- return y
+def reorder(x, order, back=False): # x is NxTxD1x...xDk, order is NxT'
+ u = x.reshape(x.size()[:2] + (-1,))
+ order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
+ if back:
+ v = u.new(u.size())
+ v.scatter_(1, order, u)