X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=3c92e8d2c12d9027ccc9dba87c572992eaf19e89;hb=459640c9228bf0ee4b817bfaeb1863ea4fce45b3;hp=5195cc9a0a50fd8301addc29c9e33de7cb7c8b26;hpb=9b9d7bc878171bd65b0c8a803494a2e4ef00c5fe;p=picoclvr.git diff --git a/graph.py b/graph.py index 5195cc9..3c92e8d 100755 --- a/graph.py +++ b/graph.py @@ -23,7 +23,12 @@ def save_attention_image( layer_gap=25, y_eps=0.5, padding=10, + # do not draw links with a lesser attention min_att=0, + # draw only the strongest links necessary to have less than + # residual remaining + residual=None, + # draw only the top k links k_top=None, ): attention = torch.cat( @@ -35,6 +40,12 @@ def save_attention_image( attention.sort(dim=-1, descending=True).indices < k_top ) + if residual is not None: + s = attention.sort(dim=-1) + m = 1 - (s.values.cumsum(-1) < residual).long() + b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m) + attention = attention * b + surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) ctx = cairo.Context(surface)