from torch import nn
from torch.nn import functional as F
+######################################################################
+
+#
+# This function gets a NxT tensor of long that encodes the group id of
+# each token, and returns a NxT tensor sigma of long such that for any
+# n sigma[n, :] is a permutation of {0...T-1} sampled uniformly among
+# the permutations that verify
+#
+# for any n, i, j: group[n,i] < group[n,j] => sigma[n,i] < sigma[n,j]
+#
+# For instance
+#
+# block_sigma(torch.tensor([[2, 2, 0, 0, 0, 1, 1, 1, 1, 2]]))
+#
+# could be
+#
+# tensor([[8, 7, 1, 0, 2, 5, 4, 3, 6, 9]])
+#
+
+
+def block_sigma(groups):
+ g = (groups[:, None, :] == torch.arange(groups.max() + 1)[None, :, None]).long()
+ r = g * torch.rand(g.size()) + (1 - g) * 2
+ a = torch.arange(r.size(2)).repeat(r.size(0), r.size(1), 1)
+ s = a.new(r.size()).scatter_(dim=2, index=r.argsort(dim=2), src=a) * g
+ m = g.sum(dim=2).cumsum(dim=1)
+ s[:, 1:, :] += m[:, :-1, None]
+ return (s * g).sum(dim=1)
+
+
######################################################################
# A BracketedSequence is a BxTx... tensor with a first and a nb time