# Draw the attention edges layer by layer, walking the figure downward in
# cairo coordinates (y is decremented by layer_gap after each layer).
for d in range(attention.size(0)):
    if d > 0:
        # First pass over this layer's token dots, filled with whatever
        # source color is currently set on ctx.
        # NOTE(review): the dots appear to be redrawn later with a white
        # halo — confirm this first pass is still needed.
        for n in range(attention.size(-1)):
            xc, yc = n * token_gap, -d * layer_gap
            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
            ctx.fill()

    # Attention weights for layer d; assumes a 2D (dest, src) matrix —
    # TODO confirm orientation against the caller.
    at = attention[d]
    # Index grids giving each flattened weight its (row, col) position.
    ni = torch.arange(at.size(0))[:, None].expand_as(at)
    nj = torch.arange(at.size(1))[None, :].expand_as(at)
    # Sort weights ascending so the strongest (darkest) edges are drawn
    # last and therefore end up on top of the weaker ones.
    at = at.flatten()
    o = at.sort().indices
    at = at[o]
    ni = ni.flatten()[o]
    nj = nj.flatten()[o]
    for i, j, a in zip(ni, nj, at):
        # Skip zero weights and anything below the display threshold.
        if a > 0 and a >= min_att:
            # Map weight to a gray level: 1.0 -> black, near 0 -> near white.
            c = 1 - a.item()
            ctx.set_source_rgb(c, c, c)
            ctx.set_line_width(0.5)
            # Edge from source token j on this row to dest token i on the
            # row above, shortened by y_eps so it doesn't overlap the dots.
            ctx.move_to(j * token_gap, y - y_eps)
            ctx.line_to(i * token_gap, y - layer_gap + y_eps)
            ctx.stroke()
    y -= layer_gap
+
# Redraw every intermediate-layer token as a black dot with a thin white
# halo, so the dots cleanly mask the attention lines running underneath.
num_layers = attention.size(0)
num_tokens = attention.size(-1)
radius = token_gap / 10
for layer in range(1, num_layers):
    y_dot = -layer * layer_gap
    for tok in range(num_tokens):
        x_dot = tok * token_gap
        # White halo, slightly larger than the dot itself.
        ctx.set_source_rgb(1.0, 1.0, 1.0)
        ctx.arc(x_dot, y_dot, radius + 0.5, 0, 2 * math.pi)
        ctx.fill()
        # The dot proper.
        ctx.set_source_rgb(0.0, 0.0, 0.0)
        ctx.arc(x_dot, y_dot, radius, 0, 2 * math.pi)
        ctx.fill()
+
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+
+ for k, t in enumerate(tokens_input):
+ s = str(t)