X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=3c92e8d2c12d9027ccc9dba87c572992eaf19e89;hb=459640c9228bf0ee4b817bfaeb1863ea4fce45b3;hp=5bab86183b0bd252f2e491a47e59b09a139711d2;hpb=ef3bef5253ff719953dfffff28d4122c19acdd77;p=picoclvr.git

diff --git a/graph.py b/graph.py
index 5bab861..3c92e8d 100755
--- a/graph.py
+++ b/graph.py
@@ -23,7 +23,12 @@ def save_attention_image(
     layer_gap=25,
     y_eps=0.5,
     padding=10,
+    # do not draw links with a lesser attention
     min_att=0,
+    # draw only the strongest links necessary to have less than
+    # residual remaining
+    residual=None,
+    # draw only the top k links
     k_top=None,
 ):
     attention = torch.cat(
@@ -35,6 +40,12 @@ def save_attention_image(
             attention.sort(dim=-1, descending=True).indices < k_top
         )
 
+    if residual is not None:
+        s = attention.sort(dim=-1)
+        m = 1 - (s.values.cumsum(-1) < residual).long()
+        b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
+        attention = attention * b
+
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
 
     ctx = cairo.Context(surface)
@@ -47,12 +58,6 @@ def save_attention_image(
     x, y = 0, 0
 
     for d in range(attention.size(0)):
-        if d > 0:
-            for n in range(attention.size(-1)):
-                xc, yc = n * token_gap, -d * layer_gap
-                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-                ctx.fill()
-
         at = attention[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
@@ -71,14 +76,14 @@ def save_attention_image(
             ctx.stroke()
         y -= layer_gap
 
-    for d in range(1, attention.size(0)):
+    for d in range(0, attention.size(0) + 1):
         for n in range(attention.size(-1)):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
             ctx.fill()
             ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
             ctx.fill()
 
     ctx.set_source_rgb(0.0, 0.0, 0.0)
@@ -93,7 +98,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
@@ -106,7 +111,9 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.move_to(
+            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+        )
         ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()
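
For readers who want to see the new residual filter in isolation, the sketch below restates the sort/cumsum/scatter_ logic from the second hunk on a toy tensor. The tensor shape, the seed, and the prints are illustrative assumptions, not part of graph.py; only the four lines computing s, m, b and the final product come from the patch.

import torch

torch.manual_seed(0)

# toy stand-in for the attention tensor used in graph.py:
# layers x tokens x tokens, each row normalized like softmax attention
attention = torch.softmax(torch.randn(2, 5, 5), dim=-1)
residual = 0.25

print(attention[0, 0])  # one row before filtering

# sort each row in ascending order so the weakest links come first
s = attention.sort(dim=-1)

# m is 0 for the weakest links as long as their cumulative attention
# stays below residual, and 1 for every link that is kept
m = 1 - (s.values.cumsum(-1) < residual).long()

# scatter the mask back to the original, unsorted positions
b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)

attention = attention * b

print(attention[0, 0])  # weakest links zeroed; the dropped mass sums to < residual

Unlike k_top, which keeps a fixed number of links per row, residual adapts to the shape of each attention distribution: a peaked row keeps only a few links, a flat one keeps many.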