From 6ca75e89749c2248274826dba3df6c249e365292 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 21:17:53 +0200 Subject: [PATCH] Update. --- graph.py | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/graph.py b/graph.py index 3c92e8d..a2554d2 100755 --- a/graph.py +++ b/graph.py @@ -11,28 +11,32 @@ import cairo ###################################################################### + + def save_attention_image( filename, tokens_input, tokens_output, - attention, + # An iterable set of BxHxTxT attention matrices + attention_arrays, n_sample=0, n_head=0, pixel_scale=8, - token_gap=10, + token_gap=15, layer_gap=25, y_eps=0.5, padding=10, # do not draw links with a lesser attention - min_att=0, - # draw only the strongest links necessary to have less than - # residual remaining - residual=None, + min_link_attention=0, + # draw only the strongest links necessary to reache + # min_total_attention + min_total_attention=None, # draw only the top k links k_top=None, + curved=True, ): attention = torch.cat( - [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0 + [x[n_sample : n_sample + 1, n_head] for x in attention_arrays], dim=0 ) if k_top is not None: @@ -40,9 +44,9 @@ def save_attention_image( attention.sort(dim=-1, descending=True).indices < k_top ) - if residual is not None: + if min_total_attention is not None: s = attention.sort(dim=-1) - m = 1 - (s.values.cumsum(-1) < residual).long() + m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long() b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m) attention = attention * b @@ -67,12 +71,19 @@ def save_attention_image( ni = ni.flatten()[o] nj = nj.flatten()[o] for i, j, a in zip(ni, nj, at): - if a > 0 and a >= min_att: + if a > 0 and a >= min_link_attention: c = 1 - a.item() ctx.set_source_rgb(c, c, c) ctx.set_line_width(0.5) - ctx.move_to(j * token_gap, y - y_eps) - ctx.line_to(i * token_gap, y - layer_gap + y_eps) + ax, ay = j * token_gap, y - y_eps + ctx.move_to(ax, ay) + dx, dy = i * token_gap, y - layer_gap + y_eps + if curved: + bx, by = ax, ay - layer_gap * 0.5 + cx, cy = dx, dy + layer_gap * 0.5 + ctx.curve_to(bx, by, cx, cy, dx, dy) + else: + ctx.line_to(dx, dy) ctx.stroke() y -= layer_gap @@ -133,8 +144,8 @@ def save_attention_image( if __name__ == "__main__": import mygpt - tokens_output = ["bluh", 2, 3, 4, "blih"] - tokens_input = ["n/a"] + tokens_output[:-1] + tokens_output = ["", 2, 3, 4, ""] + tokens_input = [""] + tokens_output[:-1] vocabulary_size = 3 x = torch.randint(vocabulary_size, (1, len(tokens_input))) @@ -157,4 +168,11 @@ if __name__ == "__main__": attention = model.retrieve_attention() - save_attention_image("attention.pdf", tokens_input, tokens_output, attention) + save_attention_image( + "attention.pdf", + tokens_input, + tokens_output, + attention, + # k_top=2, + min_total_attention=0.9, + ) -- 2.39.5