From: François Fleuret Date: Sat, 22 Jul 2023 18:39:58 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=459640c9228bf0ee4b817bfaeb1863ea4fce45b3;p=picoclvr.git Update. --- diff --git a/graph.py b/graph.py index 5195cc9..3c92e8d 100755 --- a/graph.py +++ b/graph.py @@ -23,7 +23,12 @@ def save_attention_image( layer_gap=25, y_eps=0.5, padding=10, + # do not draw links with a lesser attention min_att=0, + # draw only the strongest links necessary to have less than + # residual remaining + residual=None, + # draw only the top k links k_top=None, ): attention = torch.cat( @@ -35,6 +40,12 @@ def save_attention_image( attention.sort(dim=-1, descending=True).indices < k_top ) + if residual is not None: + s = attention.sort(dim=-1) + m = 1 - (s.values.cumsum(-1) < residual).long() + b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m) + attention = attention * b + surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) ctx = cairo.Context(surface)