projects
/
beaver.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
917e010
)
Update
author
François Fleuret
<francois@fleuret.org>
Mon, 27 Mar 2023 13:38:16 +0000
(15:38 +0200)
committer
François Fleuret
<francois@fleuret.org>
Mon, 27 Mar 2023 13:38:16 +0000
(15:38 +0200)
beaver.py
patch
|
blob
|
history
diff --git
a/beaver.py
b/beaver.py
index
5407859
..
074e137
100755
(executable)
--- a/
beaver.py
+++ b/
beaver.py
@@
-133,18
+133,6
@@
for n in vars(args):
######################################################################
######################################################################
-def generation_order(x, prompt_len=0):
- if args.random_regression_order:
- order = torch.rand(x.size(), device=x.device)
- order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
- order = order.sort(1).indices
- else:
- order = (
- torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
- )
- return order
-
-
def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT'
u = x.reshape(x.size()[:2] + (-1,))
order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT'
u = x.reshape(x.size()[:2] + (-1,))
order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
@@
-157,7
+145,14
@@
def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT'
def shuffle(x, prompt_len):
def shuffle(x, prompt_len):
- order = generation_order(x, prompt_len)
+ if args.random_regression_order:
+ order = torch.rand(x.size(), device=x.device)
+ order[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
+ order = order.sort(1).indices
+ else:
+ order = (
+ torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+ )
return reorder(x, order), order
return reorder(x, order), order