Update: add a residual option to save_attention_image() that keeps only the strongest attention links, draw node dots at every layer boundary, and adjust token label spacing.
diff --git a/graph.py b/graph.py
index 5bab861..3c92e8d 100755
--- a/graph.py
+++ b/graph.py
@@ -23,7 +23,12 @@ def save_attention_image(
     layer_gap=25,
     y_eps=0.5,
     padding=10,
+    # do not draw links with a lesser attention
     min_att=0,
+    # draw only the strongest links necessary to have less than
+    # residual remaining
+    residual=None,
+    # draw only the top k links
     k_top=None,
 ):
     attention = torch.cat(
@@ -35,6 +40,12 @@ def save_attention_image(
             attention.sort(dim=-1, descending=True).indices < k_top
         )

+    if residual is not None:
+        s = attention.sort(dim=-1)
+        m = 1 - (s.values.cumsum(-1) < residual).long()
+        b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
+        attention = attention * b
+
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
     ctx = cairo.Context(surface)
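The new residual filter keeps, per token, only the strongest links, dropping the weakest ones as long as their total attention stays below residual. A minimal self-contained sketch of that masking, mirroring the lines above; the helper name filter_residual and the toy tensor are illustrative only:

import torch

def filter_residual(attention, residual):
    # sort each row ascending, weakest links first
    s = attention.sort(dim=-1)
    # mask the weakest links whose cumulative attention stays
    # below `residual` (0 = drop, 1 = keep)
    m = 1 - (s.values.cumsum(-1) < residual).long()
    # scatter the mask back to the original, unsorted positions
    b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
    return attention * b

a = torch.tensor([[0.05, 0.10, 0.60, 0.25]])
print(filter_residual(a, residual=0.2))
# tensor([[0.0000, 0.0000, 0.6000, 0.2500]]) -- the two weakest links,
# with total attention 0.15 < 0.2, are dropped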
@@ -47,12 +58,6 @@ def save_attention_image(
     x, y = 0, 0

     for d in range(attention.size(0)):
-        if d > 0:
-            for n in range(attention.size(-1)):
-                xc, yc = n * token_gap, -d * layer_gap
-                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-                ctx.fill()
-
         at = attention[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
@@ -71,14 +76,14 @@ def save_attention_image(
             ctx.stroke()

         y -= layer_gap

-    for d in range(1, attention.size(0)):
+    for d in range(0, attention.size(0) + 1):
         for n in range(attention.size(-1)):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
             ctx.fill()
             ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
             ctx.fill()
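A note on the range change above: a stack of L attention matrices spans L + 1 rows of tokens. The old code drew the inner rows twice (once in the block removed in the -47,12 hunk, once in range(1, attention.size(0))) and never the outermost input/output rows; the single loop over range(0, attention.size(0) + 1) now covers every boundary. A toy sketch of the resulting dot grid, with made-up sizes:

# made-up sizes, for illustration only
L, n_tokens, token_gap, layer_gap = 2, 3, 10, 25

for d in range(0, L + 1):
    # one row of dots per layer boundary: y = 0, -25, -50
    print([(n * token_gap, -d * layer_gap) for n in range(n_tokens)])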
@@ -93,7 +98,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
         ctx.show_text(s)

     for k, t in enumerate(tokens_output):
@@ -106,7 +111,9 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.move_to(
+            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+        )
         ctx.show_text(s)

     x, y, width, height = surface.ink_extents()
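Put together, a hypothetical call exercising the new option. Only the keyword names come from this diff; the positional arguments are assumptions about the rest of the signature, which the hunks above do not show:

# positional arguments below are assumed, not shown in this diff
save_attention_image(
    "attention.pdf",
    tokens_input,
    tokens_output,
    attention,      # e.g. a list of per-layer attention matrices
    min_att=0.01,   # hide links with attention below 0.01
    residual=0.25,  # drop the weakest links totalling less than 0.25
    k_top=None,     # optionally keep only the k strongest links
)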