5 import torch, torchvision
8 from torch.nn import functional as F
13 ######################################################################
# Draws a row of token labels per layer with lines linking tokens whose
# attention weight is at least min_att, then sets up a PDF surface for `filename`.
# NOTE(review): the parameter list and several body lines are not visible in
# this view; comments below describe only the statements shown here.
14 def save_attention_image(
25 # surface = cairo.PDFSurface(
26 # filename, surface_width * pixel_scale, surface_height * pixel_scale
# Draw into a recording surface created with extents=None, so the drawing's
# actual size can be measured afterwards with ink_extents().
29 surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
31 ctx = cairo.Context(surface)
# Apply the same scale factor on both axes to everything drawn below.
32 ctx.scale(pixel_scale, pixel_scale)
# Black source colour and font size 4.0 for the text measurements/drawing.
34 ctx.set_source_rgb(0.0, 0.0, 0.0)
35 ctx.set_font_size(4.0)
36 # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
# Lay the tokens out left to right: measure each label and record its geometry.
# NOTE(review): `string`, `u`, `x` and `y` are set in lines not visible here;
# presumably string = str(t) and u is a dict -- confirm.
41 for n, t in enumerate(tokens):
50 ) = ctx.text_extents(string)
# Entry layout: (label, left x, horizontal centre x, text height, y bearing).
51 u[n]=(string, x, x + width_t / 2, height_t, y_bearing)
# Move the pen past this label plus the configured gap between tokens.
52 x += x_advance + token_gap
# One more pass than attention.size(0); the guard below only draws links
# when d indexes a valid slice of `attention` along dim 0.
55 for d in range(attention.size(0) + 1):
# NOTE(review): `tokens` is iterated as a mapping of the tuples stored in `u`
# above; the rebinding (e.g. tokens = u) is presumably in a line not shown.
56 for n, (s, x, xc, h, yb) in tokens.items():
57 if d < attention.size(0):
58 for m, (_, _, x2c, h2, y2b) in tokens.items():
# Only draw links whose weight reaches the min_att threshold.
59 if attention[d, n, m] >= min_att:
# Grey level: a larger attention value gives a smaller c, i.e. a darker line.
60 c = 1 - attention[d, n, m]
61 ctx.set_source_rgb(c, c, c)
62 ctx.set_line_width(0.5)
# Segment from token n's centre at (y + yb + h + y_eps) to token m's centre
# one layer_gap further along y at (y + layer_gap + y2b - y_eps).
# With cairo text extents, y + y_bearing is the top of the text box and adding
# the height gives its bottom, so y_eps keeps the line clear of both labels.
63 ctx.move_to(xc, y + yb + h + y_eps)
64 ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
66 # ctx.set_source_rgb(0.0, 0.0, 0.0)
67 # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
# Back to black after the grey link colour set above.
69 ctx.set_source_rgb(0.0, 0.0, 0.0)
# Bounding box of everything actually drawn on the recording surface;
# note this rebinds x and y, which were layout coordinates above.
75 x, y, width, height = surface.ink_extents()
# Create the PDF with the measured width/height and use the recording as its
# source, offset by (-x, -y) so the ink's top-left corner lands on the PDF origin.
# NOTE(review): any padding adjustment and the final paint/finish calls are not
# visible in this view.
80 pdf_surface = cairo.PDFSurface(filename, width, height)
81 ctx_pdf = cairo.Context(pdf_surface)
82 ctx_pdf.set_source_surface(surface, -x, -y)
87 ######################################################################
# Demo entry point: runs a small model with attention recording turned on, then
# writes "attention.pdf" from a random attention tensor.
# NOTE(review): the model construction and some setup lines are not visible here.
89 if __name__ == "__main__":
# Random token indices in [0, vocabulary_size), shape (1, 5).
93 x = torch.randint(vocabulary_size, (1, 5))
96 vocabulary_size=vocabulary_size,
# Enable attention recording before the forward pass, then fetch the result after it.
107 model.record_attention()
109 y1 = model(mygpt.BracketedSequence(x)).x
111 a = model.retrieve_attention()
# NOTE(review): `x[:0]` is an empty slice along dim 0, so this concatenation has
# size 0 along dim 0 if the items of `a` are tensors; `x[0]` or `x[:1]` was
# presumably intended -- confirm. This value is also overwritten two statements
# below before it is ever used.
113 attention = torch.cat([x[:0] for x in a], dim=0)
# Mixed str/int tokens used as labels for the demo picture.
115 tokens = ["bluh", 2, 3, 4, "blih"]
# Replaces the recorded attention with random weights: 3 slices of shape
# (len(tokens), len(tokens)), each row normalised by softmax over the last dim.
116 attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
118 save_attention_image("attention.pdf", tokens, attention, padding=3)