From 9c4325b877ede05e14699ddae211d1edc83c1515 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 27 Mar 2023 15:38:16 +0200 Subject: [PATCH] Update --- beaver.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/beaver.py b/beaver.py index 5407859..074e137 100755 --- a/beaver.py +++ b/beaver.py @@ -133,18 +133,6 @@ for n in vars(args): ###################################################################### -def generation_order(x, prompt_len=0): - if args.random_regression_order: - order = torch.rand(x.size(), device=x.device) - order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device) - order = order.sort(1).indices - else: - order = ( - torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) - ) - return order - - def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT' u = x.reshape(x.size()[:2] + (-1,)) order = order.unsqueeze(-1).expand(-1, -1, u.size(-1)) @@ -157,7 +145,14 @@ def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT' def shuffle(x, prompt_len): - order = generation_order(x, prompt_len) + if args.random_regression_order: + order = torch.rand(x.size(), device=x.device) + order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device) + order = order.sort(1).indices + else: + order = ( + torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) + ) return reorder(x, order), order -- 2.39.5