Update.
[picoclvr.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
14 def save_attention_image(
15     filename,
16     tokens_input,
17     tokens_output,
18     attention,
19     n_sample=0,
20     n_head=0,
21     pixel_scale=8,
22     token_gap=10,
23     layer_gap=25,
24     y_eps=0.5,
25     padding=10,
26     min_att=0,
27     k_top=None,
28 ):
29     attention = torch.cat(
30         [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0
31     )
32
33     if k_top is not None:
34         attention = attention * (
35             attention.sort(dim=-1, descending=True).indices < k_top
36         )
37
38     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
39
40     ctx = cairo.Context(surface)
41     ctx.scale(pixel_scale, pixel_scale)
42
43     ctx.set_source_rgb(0.0, 0.0, 0.0)
44     ctx.set_font_size(4.0)
45     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
46
47     x, y = 0, 0
48
49     for d in range(attention.size(0)):
50         if d > 0:
51             for n in range(attention.size(-1)):
52                 xc, yc = n * token_gap, -d * layer_gap
53                 ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
54                 ctx.fill()
55
56         at = attention[d]
57         ni = torch.arange(at.size(0))[:, None].expand_as(at)
58         nj = torch.arange(at.size(1))[None, :].expand_as(at)
59         at = at.flatten()
60         o = at.sort().indices
61         at = at[o]
62         ni = ni.flatten()[o]
63         nj = nj.flatten()[o]
64         for i, j, a in zip(ni, nj, at):
65             if a > 0 and a >= min_att:
66                 c = 1 - a.item()
67                 ctx.set_source_rgb(c, c, c)
68                 ctx.set_line_width(0.5)
69                 ctx.move_to(j * token_gap, y - y_eps)
70                 ctx.line_to(i * token_gap, y - layer_gap + y_eps)
71                 ctx.stroke()
72         y -= layer_gap
73
74     for d in range(1, attention.size(0)):
75         for n in range(attention.size(-1)):
76             xc, yc = n * token_gap, -d * layer_gap
77             ctx.set_source_rgb(1.0, 1.0, 1.0)
78             ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
79             ctx.fill()
80             ctx.set_source_rgb(0.0, 0.0, 0.0)
81             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
82             ctx.fill()
83
84     ctx.set_source_rgb(0.0, 0.0, 0.0)
85
86     for k, t in enumerate(tokens_input):
87         s = str(t)
88         (
89             x_bearing,
90             y_bearing,
91             width_t,
92             height_t,
93             x_advance,
94             y_advance,
95         ) = ctx.text_extents(s)
96         ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
97         ctx.show_text(s)
98
99     for k, t in enumerate(tokens_output):
100         s = str(t)
101         (
102             x_bearing,
103             y_bearing,
104             width_t,
105             height_t,
106             x_advance,
107             y_advance,
108         ) = ctx.text_extents(s)
109         ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
110         ctx.show_text(s)
111
112     x, y, width, height = surface.ink_extents()
113     x -= padding
114     y -= padding
115     width += 2 * padding
116     height += 2 * padding
117     pdf_surface = cairo.PDFSurface(filename, width, height)
118     ctx_pdf = cairo.Context(pdf_surface)
119     ctx_pdf.set_source_surface(surface, -x, -y)
120     ctx_pdf.paint()
121     pdf_surface.finish()
122
123
124 ######################################################################
125
126 if __name__ == "__main__":
127     import mygpt
128
129     tokens_output = ["bluh", 2, 3, 4, "blih"]
130     tokens_input = ["n/a"] + tokens_output[:-1]
131
132     vocabulary_size = 3
133     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
134
135     model = mygpt.MyGPT(
136         vocabulary_size=vocabulary_size,
137         dim_model=4,
138         dim_keys=2,
139         dim_hidden=2,
140         nb_heads=2,
141         nb_blocks=3,
142         dropout=0.1,
143         causal=True,
144     )
145
146     model.eval()
147     model.record_attention()
148
149     y1 = model(mygpt.BracketedSequence(x)).x
150
151     attention = model.retrieve_attention()
152
153     save_attention_image("attention.pdf", tokens_input, tokens_output, attention)