Update.
[picoclvr.git] / graph.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
10 import cairo
11
12
13 ######################################################################
14 def save_attention_image(
15     filename,
16     tokens_input,
17     tokens_output,
18     attention,
19     n_sample=0,
20     n_head=0,
21     pixel_scale=8,
22     token_gap=10,
23     layer_gap=25,
24     y_eps=0.5,
25     padding=10,
26     # do not draw links with a lesser attention
27     min_att=0,
28     # draw only the strongest links necessary to have less than
29     # residual remaining
30     residual=None,
31     # draw only the top k links
32     k_top=None,
33 ):
34     attention = torch.cat(
35         [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0
36     )
37
38     if k_top is not None:
39         attention = attention * (
40             attention.sort(dim=-1, descending=True).indices < k_top
41         )
42
43     if residual is not None:
44         s = attention.sort(dim=-1)
45         m = 1 - (s.values.cumsum(-1) < residual).long()
46         b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
47         attention = attention * b
48
49     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
50
51     ctx = cairo.Context(surface)
52     ctx.scale(pixel_scale, pixel_scale)
53
54     ctx.set_source_rgb(0.0, 0.0, 0.0)
55     ctx.set_font_size(4.0)
56     # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
57
58     x, y = 0, 0
59
60     for d in range(attention.size(0)):
61         at = attention[d]
62         ni = torch.arange(at.size(0))[:, None].expand_as(at)
63         nj = torch.arange(at.size(1))[None, :].expand_as(at)
64         at = at.flatten()
65         o = at.sort().indices
66         at = at[o]
67         ni = ni.flatten()[o]
68         nj = nj.flatten()[o]
69         for i, j, a in zip(ni, nj, at):
70             if a > 0 and a >= min_att:
71                 c = 1 - a.item()
72                 ctx.set_source_rgb(c, c, c)
73                 ctx.set_line_width(0.5)
74                 ctx.move_to(j * token_gap, y - y_eps)
75                 ctx.line_to(i * token_gap, y - layer_gap + y_eps)
76                 ctx.stroke()
77         y -= layer_gap
78
79     for d in range(0, attention.size(0) + 1):
80         for n in range(attention.size(-1)):
81             xc, yc = n * token_gap, -d * layer_gap
82             ctx.set_source_rgb(1.0, 1.0, 1.0)
83             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
84             ctx.fill()
85             ctx.set_source_rgb(0.0, 0.0, 0.0)
86             ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
87             ctx.fill()
88
89     ctx.set_source_rgb(0.0, 0.0, 0.0)
90
91     for k, t in enumerate(tokens_input):
92         s = str(t)
93         (
94             x_bearing,
95             y_bearing,
96             width_t,
97             height_t,
98             x_advance,
99             y_advance,
100         ) = ctx.text_extents(s)
101         ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
102         ctx.show_text(s)
103
104     for k, t in enumerate(tokens_output):
105         s = str(t)
106         (
107             x_bearing,
108             y_bearing,
109             width_t,
110             height_t,
111             x_advance,
112             y_advance,
113         ) = ctx.text_extents(s)
114         ctx.move_to(
115             k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
116         )
117         ctx.show_text(s)
118
119     x, y, width, height = surface.ink_extents()
120     x -= padding
121     y -= padding
122     width += 2 * padding
123     height += 2 * padding
124     pdf_surface = cairo.PDFSurface(filename, width, height)
125     ctx_pdf = cairo.Context(pdf_surface)
126     ctx_pdf.set_source_surface(surface, -x, -y)
127     ctx_pdf.paint()
128     pdf_surface.finish()
129
130
131 ######################################################################
132
133 if __name__ == "__main__":
134     import mygpt
135
136     tokens_output = ["bluh", 2, 3, 4, "blih"]
137     tokens_input = ["n/a"] + tokens_output[:-1]
138
139     vocabulary_size = 3
140     x = torch.randint(vocabulary_size, (1, len(tokens_input)))
141
142     model = mygpt.MyGPT(
143         vocabulary_size=vocabulary_size,
144         dim_model=4,
145         dim_keys=2,
146         dim_hidden=2,
147         nb_heads=2,
148         nb_blocks=3,
149         dropout=0.1,
150         causal=True,
151     )
152
153     model.eval()
154     model.record_attention()
155
156     y1 = model(mygpt.BracketedSequence(x)).x
157
158     attention = model.retrieve_attention()
159
160     save_attention_image("attention.pdf", tokens_input, tokens_output, attention)