From b59fca62aa31de18a3e0cd0bb54e395d4b1254ae Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 14:29:26 +0200 Subject: [PATCH] Update. --- graph.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/graph.py b/graph.py index 97de6d1..0db7bd0 100755 --- a/graph.py +++ b/graph.py @@ -15,14 +15,11 @@ def save_attention_image( filename, tokens, attention, - surface_width=128, - surface_height=96, pixel_scale=8, - x=10, - y=10, - token_gap=15, + token_gap=10, layer_gap=25, - y_eps=1, + y_eps=1.5, + padding=0, min_att=1e-2, ): # surface = cairo.PDFSurface( @@ -38,7 +35,9 @@ def save_attention_image( ctx.set_font_size(4.0) # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) - u = [] + x, y = 0, 0 + + u = {} for n, t in enumerate(tokens): string = str(t) ( @@ -49,21 +48,14 @@ def save_attention_image( x_advance, y_advance, ) = ctx.text_extents(string) - u.append((n, string, x, x + width_t / 2, height_t, y_bearing)) + u[n]=(string, x, x + width_t / 2, height_t, y_bearing) x += x_advance + token_gap tokens = u for d in range(attention.size(0) + 1): - for n, s, x, xc, h, yb in tokens: - # ctx.set_source_rgb(0.0, 0.0, 0.0) - # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) - # ctx.stroke() - ctx.set_source_rgb(0.0, 0.0, 0.0) - ctx.move_to(x, y) - ctx.show_text(s) - # x += x_advance + 1 + for n, (s, x, xc, h, yb) in tokens.items(): if d < attention.size(0): - for m, _, _, x2c, h2, y2b in tokens: + for m, (_, _, x2c, h2, y2b) in tokens.items(): if attention[d, n, m] >= min_att: c = 1 - attention[d, n, m] ctx.set_source_rgb(c, c, c) @@ -71,9 +63,20 @@ def save_attention_image( ctx.move_to(xc, y + yb + h + y_eps) ctx.line_to(x2c, y + layer_gap + y2b - y_eps) ctx.stroke() + # ctx.set_source_rgb(0.0, 0.0, 0.0) + # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) + # ctx.stroke() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.move_to(x, y) + ctx.show_text(s) + # x += x_advance + 1 y += layer_gap x, y, width, height = surface.ink_extents() + x -= padding + y -= padding + width += 2 * padding + height += 2 * padding pdf_surface = cairo.PDFSurface(filename, width, height) ctx_pdf = cairo.Context(pdf_surface) ctx_pdf.set_source_surface(surface, -x, -y) @@ -112,4 +115,4 @@ if __name__ == "__main__": tokens = ["bluh", 2, 3, 4, "blih"] attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1) - save_attention_image("attention.pdf", tokens, attention) + save_attention_image("attention.pdf", tokens, attention, padding=3) -- 2.39.5