+def generation_order(x, prompt_len=0):
+ if args.random_regression_order:
+ order = torch.rand(x.size(), device=x.device)
+ order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
+ order = order.sort(1).indices
+ else:
+ order = (
+ torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+ )
+ return order
+
+
+def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT'
+ u = x.reshape(x.size()[:2] + (-1,))
+ order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
+ if reverse:
+ v = u.new(u.size()).scatter_(1, order, u)
+ else:
+ v = u.gather(1, order)
+ v = v.reshape(v.size()[:2] + x.size()[2:])
+ return v
+
+
+def shuffle(x, prompt_len):
+ order = generation_order(x, prompt_len)
+ return reorder(x, order), order
+
+
+def eval_mygpt(model, input, mode="standard", prompt_len=0):
+ x, order = shuffle(input, prompt_len)
+ x = model(mygpt.BracketedSequence(x), mode=mode, order=order).x
+ return reorder(x, order, reverse=True)
+
+
+######################################################################
+