X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=07e376a3af86d24874523775e26f16132a489512;hb=21ed4aa91d0f1ac87ec684d8808e5ced552ad457;hp=3c92e8d2c12d9027ccc9dba87c572992eaf19e89;hpb=459640c9228bf0ee4b817bfaeb1863ea4fce45b3;p=picoclvr.git
diff --git a/graph.py b/graph.py
index 3c92e8d..07e376a 100755
--- a/graph.py
+++ b/graph.py
@@ -11,40 +11,43 @@ import cairo
 ######################################################################
+
+
 def save_attention_image(
+    # image to save
     filename,
     tokens_input,
     tokens_output,
-    attention,
-    n_sample=0,
-    n_head=0,
+    # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
+    attention_matrices,
+    # do not draw links with an attention below this threshold
+    min_link_attention=0,
+    # draw only the strongest links necessary so that their summed
+    # attention is above min_total_attention
+    min_total_attention=None,
+    # draw only the top k links
+    k_top=None,
+    # the purely graphical settings
+    curved=True,
     pixel_scale=8,
-    token_gap=10,
+    token_gap=15,
     layer_gap=25,
     y_eps=0.5,
    padding=10,
-    # do not draw links with a lesser attention
-    min_att=0,
-    # draw only the strongest links necessary to have less than
-    # residual remaining
-    residual=None,
-    # draw only the top k links
-    k_top=None,
 ):
-    attention = torch.cat(
-        [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0
-    )
-    if k_top is not None:
-        attention = attention * (
-            attention.sort(dim=-1, descending=True).indices < k_top
-        )
-
-    if residual is not None:
-        s = attention.sort(dim=-1)
-        m = 1 - (s.values.cumsum(-1) < residual).long()
-        b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
-        attention = attention * b
+    if k_top is not None:
+        am = []
+        for m in attention_matrices:
+            # rank the links of every row and keep the k_top strongest
+            ranks = m.sort(dim=-1, descending=True).indices.argsort(dim=-1)
+            am.append(m * (ranks < k_top))
+        attention_matrices = am
+
+    if min_total_attention is not None:
+        am = []
+        for m in attention_matrices:
+            s = m.sort(dim=-1)
+            # mask out the weakest links, in sorted order, so that the
+            # attention kept in every row sums to min_total_attention
+            mask = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
+            b = mask.new(m.size()).scatter_(dim=-1, index=s.indices, src=mask)
+            am.append(m * b)
+        attention_matrices = am
 
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
@@ -57,8 +60,9 @@ def save_attention_image(
 
     x, y = 0, 0
 
-    for d in range(attention.size(0)):
-        at = attention[d]
+    ctx.set_line_width(0.25)
+    for d in range(len(attention_matrices)):
+        at = attention_matrices[d].to("cpu")
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()
@@ -67,17 +71,28 @@ def save_attention_image(
         ni = ni.flatten()[o]
         nj = nj.flatten()[o]
         for i, j, a in zip(ni, nj, at):
-            if a > 0 and a >= min_att:
+            if a > 0 and a >= min_link_attention:
                 c = 1 - a.item()
                 ctx.set_source_rgb(c, c, c)
-                ctx.set_line_width(0.5)
-                ctx.move_to(j * token_gap, y - y_eps)
-                ctx.line_to(i * token_gap, y - layer_gap + y_eps)
+                ax, ay = j * token_gap, y - y_eps
+                ctx.move_to(ax, ay)
+                dx, dy = i * token_gap, y - layer_gap + y_eps
+                if curved:
+                    # cubic Bezier with vertical tangents at both ends
+                    bx, by = ax, ay - layer_gap * 0.5
+                    cx, cy = dx, dy + layer_gap * 0.5
+                    ctx.curve_to(bx, by, cx, cy, dx, dy)
+                else:
+                    ctx.line_to(dx, dy)
                 ctx.stroke()
 
         y -= layer_gap
 
-    for d in range(0, attention.size(0) + 1):
-        for n in range(attention.size(-1)):
+    for d in range(0, len(attention_matrices) + 1):
+        # the bottom row has as many tokens as the first matrix has
+        # columns, every other row as many as the matrix below has rows
+        nb_tokens = (
+            attention_matrices[0].size(-1)
+            if d == 0
+            else attention_matrices[d - 1].size(-2)
+        )
+        for n in range(nb_tokens):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
@@ -98,7 +113,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
@@ -112,7 +127,8 @@ def save_attention_image(
             y_advance,
         ) = ctx.text_extents(s)
         ctx.move_to(
-            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+            k * token_gap - width_t / 2,
+            -token_gap / 5 - len(attention_matrices) * layer_gap,
         )
         ctx.show_text(s)
 
@@ -133,8 +149,8 @@ def save_attention_image(
 if __name__ == "__main__":
     import mygpt
 
-    tokens_output = ["bluh", 2, 3, 4, "blih"]
-    tokens_input = ["n/a"] + tokens_output[:-1]
+    tokens_output = ["", "-", 3, 4, ""]
+    tokens_input = [""] + tokens_output[:-1]
 
     vocabulary_size = 3
     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
@@ -145,7 +161,7 @@ if __name__ == "__main__":
         dim_keys=2,
         dim_hidden=2,
         nb_heads=2,
-        nb_blocks=3,
+        nb_blocks=5,
         dropout=0.1,
         causal=True,
     )
@@ -155,6 +171,15 @@ if __name__ == "__main__":
 
     y1 = model(mygpt.BracketedSequence(x)).x
 
-    attention = model.retrieve_attention()
+    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+    # attention_matrices = [torch.rand(*s) for s in [(4, 5), (3, 4), (8, 3), (5, 8)]]
 
-    save_attention_image("attention.pdf", tokens_input, tokens_output, attention)
+    save_attention_image(
+        "attention.pdf",
+        tokens_input,
+        tokens_output,
+        attention_matrices,
+        # k_top=2,
+        min_total_attention=0.9,
+    )
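
Note on the k_top filter in the first hunk: sorting each row in descending
order and taking argsort of the resulting indices inverts the permutation,
yielding each link's rank within its row, so `ranks < k_top` keeps exactly
the k_top strongest links. A minimal sketch checking this on a made-up row
(assumes only torch; values are illustrative):

    import torch

    m = torch.tensor([[0.1, 0.9, 0.5]])
    k_top = 1

    # argsort of the descending-sort indices gives every entry's rank
    ranks = m.sort(dim=-1, descending=True).indices.argsort(dim=-1)
    print(m * (ranks < k_top))  # tensor([[0.0000, 0.9000, 0.0000]])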
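
Likewise for min_total_attention: the ascending sort plus cumulative sum
masks out the weakest links whose attentions together stay below
1 - min_total_attention, so the links kept in every row sum to at least
min_total_attention. A sketch with a made-up row that sums to 1:

    import torch

    m = torch.tensor([[0.50, 0.05, 0.30, 0.15]])
    min_total_attention = 0.8

    s = m.sort(dim=-1)  # ascending, weakest links first
    mask = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
    # scatter the mask back from sorted order to the original positions
    b = mask.new(m.size()).scatter_(dim=-1, index=s.indices, src=mask)

    # only the 0.05 link is dropped: 0.50 + 0.30 + 0.15 = 0.95 >= 0.8
    print(m * b)  # tensor([[0.5000, 0.0000, 0.3000, 0.1500]])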
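
The old signature picked a sample and a head (n_sample, n_head) inside the
function; the new one leaves that selection to the caller, as in the
__main__ block where m[0, 0] keeps sample 0, head 0 of every layer.
Averaging over heads instead is a one-line variant (a sketch, assuming the
per-layer tensors keep the sample x head x T x T layout implied by the old
indexing):

    attention_matrices = [m[0].mean(dim=0) for m in model.retrieve_attention()]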
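
As the commented-out line in __main__ hints, the function can also be
exercised without a model. A sketch with random matrices following the
T2xT1, ..., TkxTk-1 shape convention, with save_attention_image from
graph.py in scope; the layer widths, token labels, and output filename are
made up, and the rows are normalized so the min_total_attention filter is
meaningful:

    import torch

    # layer widths 5 -> 4 -> 3 -> 8 -> 5
    attention_matrices = [
        torch.rand(*s).softmax(dim=-1) for s in [(4, 5), (3, 4), (8, 3), (5, 8)]
    ]

    save_attention_image(
        "random_attention.pdf",
        [f"i{k}" for k in range(5)],  # one label per column of the first matrix
        [f"o{k}" for k in range(5)],  # one label per row of the last matrix
        attention_matrices,
        min_total_attention=0.9,
    )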