X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=5bab86183b0bd252f2e491a47e59b09a139711d2;hb=ef3bef5253ff719953dfffff28d4122c19acdd77;hp=97de6d108daf435553fe8cb22c3c427a1af391fe;hpb=00b2d5ed01fb523fbc4e699f0419329efbee0ea8;p=picoclvr.git diff --git a/graph.py b/graph.py index 97de6d1..5bab861 100755 --- a/graph.py +++ b/graph.py @@ -13,21 +13,27 @@ import cairo ###################################################################### def save_attention_image( filename, - tokens, + tokens_input, + tokens_output, attention, - surface_width=128, - surface_height=96, + n_sample=0, + n_head=0, pixel_scale=8, - x=10, - y=10, - token_gap=15, + token_gap=10, layer_gap=25, - y_eps=1, - min_att=1e-2, + y_eps=0.5, + padding=10, + min_att=0, + k_top=None, ): - # surface = cairo.PDFSurface( - # filename, surface_width * pixel_scale, surface_height * pixel_scale - # ) + attention = torch.cat( + [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0 + ) + + if k_top is not None: + attention = attention * ( + attention.sort(dim=-1, descending=True).indices < k_top + ) surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) @@ -38,9 +44,47 @@ def save_attention_image( ctx.set_font_size(4.0) # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) - u = [] - for n, t in enumerate(tokens): - string = str(t) + x, y = 0, 0 + + for d in range(attention.size(0)): + if d > 0: + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + at = attention[d] + ni = torch.arange(at.size(0))[:, None].expand_as(at) + nj = torch.arange(at.size(1))[None, :].expand_as(at) + at = at.flatten() + o = at.sort().indices + at = at[o] + ni = ni.flatten()[o] + nj = nj.flatten()[o] + for i, j, a in zip(ni, nj, at): + if a > 0 and a >= min_att: + c = 1 - a.item() + ctx.set_source_rgb(c, c, c) + ctx.set_line_width(0.5) + ctx.move_to(j * token_gap, y - y_eps) + ctx.line_to(i * token_gap, y - layer_gap + y_eps) + ctx.stroke() + y -= layer_gap + + for d in range(1, attention.size(0)): + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.set_source_rgb(1.0, 1.0, 1.0) + ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi) + ctx.fill() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + ctx.set_source_rgb(0.0, 0.0, 0.0) + + for k, t in enumerate(tokens_input): + s = str(t) ( x_bearing, y_bearing, @@ -48,32 +92,28 @@ def save_attention_image( height_t, x_advance, y_advance, - ) = ctx.text_extents(string) - u.append((n, string, x, x + width_t / 2, height_t, y_bearing)) - x += x_advance + token_gap - tokens = u - - for d in range(attention.size(0) + 1): - for n, s, x, xc, h, yb in tokens: - # ctx.set_source_rgb(0.0, 0.0, 0.0) - # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) - # ctx.stroke() - ctx.set_source_rgb(0.0, 0.0, 0.0) - ctx.move_to(x, y) - ctx.show_text(s) - # x += x_advance + 1 - if d < attention.size(0): - for m, _, _, x2c, h2, y2b in tokens: - if attention[d, n, m] >= min_att: - c = 1 - attention[d, n, m] - ctx.set_source_rgb(c, c, c) - ctx.set_line_width(0.5) - ctx.move_to(xc, y + yb + h + y_eps) - ctx.line_to(x2c, y + layer_gap + y2b - y_eps) - ctx.stroke() - y += layer_gap + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -y_bearing) + ctx.show_text(s) + + for k, t in enumerate(tokens_output): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap) + ctx.show_text(s) x, y, width, height = surface.ink_extents() + x -= padding + y -= padding + width += 2 * padding + height += 2 * padding pdf_surface = cairo.PDFSurface(filename, width, height) ctx_pdf = cairo.Context(pdf_surface) ctx_pdf.set_source_surface(surface, -x, -y) @@ -86,8 +126,11 @@ def save_attention_image( if __name__ == "__main__": import mygpt + tokens_output = ["bluh", 2, 3, 4, "blih"] + tokens_input = ["n/a"] + tokens_output[:-1] + vocabulary_size = 3 - x = torch.randint(vocabulary_size, (1, 5)) + x = torch.randint(vocabulary_size, (1, len(tokens_input))) model = mygpt.MyGPT( vocabulary_size=vocabulary_size, @@ -105,11 +148,6 @@ if __name__ == "__main__": y1 = model(mygpt.BracketedSequence(x)).x - a = model.retrieve_attention() - print(a) - attention = torch.cat([x[:0] for x in a], dim=0) - - tokens = ["bluh", 2, 3, 4, "blih"] - attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1) + attention = model.retrieve_attention() - save_attention_image("attention.pdf", tokens, attention) + save_attention_image("attention.pdf", tokens_input, tokens_output, attention)