x, y = 0, 0
for d in range(attention.size(0)):
- if d > 0:
- for n in range(attention.size(-1)):
- xc, yc = n * token_gap, -d * layer_gap
- ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
- ctx.fill()
-
at = attention[d]
ni = torch.arange(at.size(0))[:, None].expand_as(at)
nj = torch.arange(at.size(1))[None, :].expand_as(at)
ctx.stroke()
y -= layer_gap
- for d in range(1, attention.size(0)):
+ for d in range(0, attention.size(0) + 1):
for n in range(attention.size(-1)):
xc, yc = n * token_gap, -d * layer_gap
ctx.set_source_rgb(1.0, 1.0, 1.0)
- ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+ ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
ctx.fill()
ctx.set_source_rgb(0.0, 0.0, 0.0)
- ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+ ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
ctx.fill()
ctx.set_source_rgb(0.0, 0.0, 0.0)
x_advance,
y_advance,
) = ctx.text_extents(s)
- ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+ ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
ctx.show_text(s)
for k, t in enumerate(tokens_output):
x_advance,
y_advance,
) = ctx.text_extents(s)
- ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+ ctx.move_to(
+ k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+ )
ctx.show_text(s)
x, y, width, height = surface.ink_extents()