projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
00b2d5e
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 22 Jul 2023 12:29:26 +0000
(14:29 +0200)
committer
François Fleuret
<francois@fleuret.org>
Sat, 22 Jul 2023 12:29:26 +0000
(14:29 +0200)
graph.py
patch
|
blob
|
history
diff --git
a/graph.py
b/graph.py
index
97de6d1
..
0db7bd0
100755
(executable)
--- a/
graph.py
+++ b/
graph.py
@@
-15,14
+15,11
@@
def save_attention_image(
filename,
tokens,
attention,
filename,
tokens,
attention,
- surface_width=128,
- surface_height=96,
pixel_scale=8,
pixel_scale=8,
- x=10,
- y=10,
- token_gap=15,
+ token_gap=10,
layer_gap=25,
layer_gap=25,
- y_eps=1,
+ y_eps=1.5,
+ padding=0,
min_att=1e-2,
):
# surface = cairo.PDFSurface(
min_att=1e-2,
):
# surface = cairo.PDFSurface(
@@
-38,7
+35,9
@@
def save_attention_image(
ctx.set_font_size(4.0)
# ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
ctx.set_font_size(4.0)
# ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
- u = []
+ x, y = 0, 0
+
+ u = {}
for n, t in enumerate(tokens):
string = str(t)
(
for n, t in enumerate(tokens):
string = str(t)
(
@@
-49,21
+48,14
@@
def save_attention_image(
x_advance,
y_advance,
) = ctx.text_extents(string)
x_advance,
y_advance,
) = ctx.text_extents(string)
- u
.append((n, string, x, x + width_t / 2, height_t, y_bearing)
)
+ u
[n]=(string, x, x + width_t / 2, height_t, y_bearing
)
x += x_advance + token_gap
tokens = u
for d in range(attention.size(0) + 1):
x += x_advance + token_gap
tokens = u
for d in range(attention.size(0) + 1):
- for n, s, x, xc, h, yb in tokens:
- # ctx.set_source_rgb(0.0, 0.0, 0.0)
- # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
- # ctx.stroke()
- ctx.set_source_rgb(0.0, 0.0, 0.0)
- ctx.move_to(x, y)
- ctx.show_text(s)
- # x += x_advance + 1
+ for n, (s, x, xc, h, yb) in tokens.items():
if d < attention.size(0):
if d < attention.size(0):
- for m,
_, _, x2c, h2, y2b in tokens
:
+ for m,
(_, _, x2c, h2, y2b) in tokens.items()
:
if attention[d, n, m] >= min_att:
c = 1 - attention[d, n, m]
ctx.set_source_rgb(c, c, c)
if attention[d, n, m] >= min_att:
c = 1 - attention[d, n, m]
ctx.set_source_rgb(c, c, c)
@@
-71,9
+63,20
@@
def save_attention_image(
ctx.move_to(xc, y + yb + h + y_eps)
ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
ctx.stroke()
ctx.move_to(xc, y + yb + h + y_eps)
ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
ctx.stroke()
+ # ctx.set_source_rgb(0.0, 0.0, 0.0)
+ # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
+ # ctx.stroke()
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.move_to(x, y)
+ ctx.show_text(s)
+ # x += x_advance + 1
y += layer_gap
x, y, width, height = surface.ink_extents()
y += layer_gap
x, y, width, height = surface.ink_extents()
+ x -= padding
+ y -= padding
+ width += 2 * padding
+ height += 2 * padding
pdf_surface = cairo.PDFSurface(filename, width, height)
ctx_pdf = cairo.Context(pdf_surface)
ctx_pdf.set_source_surface(surface, -x, -y)
pdf_surface = cairo.PDFSurface(filename, width, height)
ctx_pdf = cairo.Context(pdf_surface)
ctx_pdf.set_source_surface(surface, -x, -y)
@@
-112,4
+115,4
@@
if __name__ == "__main__":
tokens = ["bluh", 2, 3, 4, "blih"]
attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
tokens = ["bluh", 2, 3, 4, "blih"]
attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
- save_attention_image("attention.pdf", tokens, attention)
+ save_attention_image("attention.pdf", tokens, attention
, padding=3
)