From: François Fleuret Date: Mon, 5 Aug 2024 05:33:44 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a489b3662063ad37bf04bbc52cb2526dfd135c1f;p=culture.git Update. --- diff --git a/mygpt.py b/mygpt.py index 812a139..4ffaa3d 100755 --- a/mygpt.py +++ b/mygpt.py @@ -17,6 +17,36 @@ import torch from torch import nn from torch.nn import functional as F +###################################################################### + +# +# This function gets a NxT tensor of long that encodes the group id of +# each token, and returns a NxT tensor sigma of long such that for any +# n sigma[n, :] is a permutation of {0...T-1} sampled uniformly among +# the permutations that verify +# +# for any n, i, j: group[n,i] < group[n,j] => sigma[n,i] < sigma[n,j] +# +# For instance +# +# block_sigma(torch.tensor([[2, 2, 0, 0, 0, 1, 1, 1, 1, 2]])) +# +# could be +# +# tensor([[8, 7, 1, 0, 2, 5, 4, 3, 6, 9]]) +# + + +def block_sigma(groups): + g = (groups[:, None, :] == torch.arange(groups.max() + 1)[None, :, None]).long() + r = g * torch.rand(g.size()) + (1 - g) * 2 + a = torch.arange(r.size(2)).repeat(r.size(0), r.size(1), 1) + s = a.new(r.size()).scatter_(dim=2, index=r.argsort(dim=2), src=a) * g + m = g.sum(dim=2).cumsum(dim=1) + s[:, 1:, :] += m[:, :-1, None] + return (s * g).sum(dim=1) + + ###################################################################### # A BracketedSequence is a BxTx... tensor with a first and a nb time