layer_gap=25,
y_eps=0.5,
padding=10,
+ # do not draw links with a lesser attention
min_att=0,
+ # draw only the strongest links necessary to have less than
+ # residual remaining
+ residual=None,
+ # draw only the top k links
k_top=None,
):
attention = torch.cat(
attention.sort(dim=-1, descending=True).indices < k_top
)
+ if residual is not None:
+ s = attention.sort(dim=-1)
+ m = 1 - (s.values.cumsum(-1) < residual).long()
+ b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
+ attention = attention * b
+
surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
ctx = cairo.Context(surface)
x, y = 0, 0
for d in range(attention.size(0)):
- if d > 0:
- for n in range(attention.size(-1)):
- xc, yc = n * token_gap, -d * layer_gap
- ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
- ctx.fill()
-
at = attention[d]
ni = torch.arange(at.size(0))[:, None].expand_as(at)
nj = torch.arange(at.size(1))[None, :].expand_as(at)
ctx.stroke()
y -= layer_gap
- for d in range(1, attention.size(0)):
+ for d in range(0, attention.size(0) + 1):
for n in range(attention.size(-1)):
xc, yc = n * token_gap, -d * layer_gap
ctx.set_source_rgb(1.0, 1.0, 1.0)
- ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+ ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
ctx.fill()
ctx.set_source_rgb(0.0, 0.0, 0.0)
- ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+ ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
ctx.fill()
ctx.set_source_rgb(0.0, 0.0, 0.0)
x_advance,
y_advance,
) = ctx.text_extents(s)
- ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+ ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
ctx.show_text(s)
for k, t in enumerate(tokens_output):
x_advance,
y_advance,
) = ctx.text_extents(s)
- ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+ ctx.move_to(
+ k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+ )
ctx.show_text(s)
x, y, width, height = surface.ink_extents()