X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=5bab86183b0bd252f2e491a47e59b09a139711d2;hb=ef3bef5253ff719953dfffff28d4122c19acdd77;hp=0db7bd0b08ba8a74db9adbef4d03dc90a92ce8cc;hpb=b59fca62aa31de18a3e0cd0bb54e395d4b1254ae;p=picoclvr.git diff --git a/graph.py b/graph.py index 0db7bd0..5bab861 100755 --- a/graph.py +++ b/graph.py @@ -13,18 +13,27 @@ import cairo ###################################################################### def save_attention_image( filename, - tokens, + tokens_input, + tokens_output, attention, + n_sample=0, + n_head=0, pixel_scale=8, token_gap=10, layer_gap=25, - y_eps=1.5, - padding=0, - min_att=1e-2, + y_eps=0.5, + padding=10, + min_att=0, + k_top=None, ): - # surface = cairo.PDFSurface( - # filename, surface_width * pixel_scale, surface_height * pixel_scale - # ) + attention = torch.cat( + [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0 + ) + + if k_top is not None: + attention = attention * ( + attention.sort(dim=-1, descending=True).indices < k_top + ) surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) @@ -37,9 +46,45 @@ def save_attention_image( x, y = 0, 0 - u = {} - for n, t in enumerate(tokens): - string = str(t) + for d in range(attention.size(0)): + if d > 0: + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + at = attention[d] + ni = torch.arange(at.size(0))[:, None].expand_as(at) + nj = torch.arange(at.size(1))[None, :].expand_as(at) + at = at.flatten() + o = at.sort().indices + at = at[o] + ni = ni.flatten()[o] + nj = nj.flatten()[o] + for i, j, a in zip(ni, nj, at): + if a > 0 and a >= min_att: + c = 1 - a.item() + ctx.set_source_rgb(c, c, c) + ctx.set_line_width(0.5) + ctx.move_to(j * token_gap, y - y_eps) + ctx.line_to(i * token_gap, y - layer_gap + y_eps) + ctx.stroke() + y -= layer_gap + + for d in range(1, attention.size(0)): + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.set_source_rgb(1.0, 1.0, 1.0) + ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi) + ctx.fill() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + ctx.set_source_rgb(0.0, 0.0, 0.0) + + for k, t in enumerate(tokens_input): + s = str(t) ( x_bearing, y_bearing, @@ -47,30 +92,22 @@ def save_attention_image( height_t, x_advance, y_advance, - ) = ctx.text_extents(string) - u[n]=(string, x, x + width_t / 2, height_t, y_bearing) - x += x_advance + token_gap - tokens = u - - for d in range(attention.size(0) + 1): - for n, (s, x, xc, h, yb) in tokens.items(): - if d < attention.size(0): - for m, (_, _, x2c, h2, y2b) in tokens.items(): - if attention[d, n, m] >= min_att: - c = 1 - attention[d, n, m] - ctx.set_source_rgb(c, c, c) - ctx.set_line_width(0.5) - ctx.move_to(xc, y + yb + h + y_eps) - ctx.line_to(x2c, y + layer_gap + y2b - y_eps) - ctx.stroke() - # ctx.set_source_rgb(0.0, 0.0, 0.0) - # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) - # ctx.stroke() - ctx.set_source_rgb(0.0, 0.0, 0.0) - ctx.move_to(x, y) - ctx.show_text(s) - # x += x_advance + 1 - y += layer_gap + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -y_bearing) + ctx.show_text(s) + + for k, t in enumerate(tokens_output): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap) + ctx.show_text(s) x, y, width, height = surface.ink_extents() x -= padding @@ -89,8 +126,11 @@ def save_attention_image( if __name__ == "__main__": import mygpt + tokens_output = ["bluh", 2, 3, 4, "blih"] + tokens_input = ["n/a"] + tokens_output[:-1] + vocabulary_size = 3 - x = torch.randint(vocabulary_size, (1, 5)) + x = torch.randint(vocabulary_size, (1, len(tokens_input))) model = mygpt.MyGPT( vocabulary_size=vocabulary_size, @@ -108,11 +148,6 @@ if __name__ == "__main__": y1 = model(mygpt.BracketedSequence(x)).x - a = model.retrieve_attention() - print(a) - attention = torch.cat([x[:0] for x in a], dim=0) - - tokens = ["bluh", 2, 3, 4, "blih"] - attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1) + attention = model.retrieve_attention() - save_attention_image("attention.pdf", tokens, attention, padding=3) + save_attention_image("attention.pdf", tokens_input, tokens_output, attention)