Update.
[picoclvr.git] / graph.py
index 5bab861..3c92e8d 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -23,7 +23,12 @@ def save_attention_image(
     layer_gap=25,
     y_eps=0.5,
     padding=10,
+    # do not draw links with a lesser attention
     min_att=0,
+    # draw only the strongest links necessary to have less than
+    # residual remaining
+    residual=None,
+    # draw only the top k links
     k_top=None,
 ):
     attention = torch.cat(
@@ -35,6 +40,12 @@ def save_attention_image(
             attention.sort(dim=-1, descending=True).indices < k_top
         )
 
+    if residual is not None:
+        s = attention.sort(dim=-1)
+        m = 1 - (s.values.cumsum(-1) < residual).long()
+        b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
+        attention = attention * b
+
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
 
     ctx = cairo.Context(surface)
@@ -47,12 +58,6 @@ def save_attention_image(
     x, y = 0, 0
 
     for d in range(attention.size(0)):
-        if d > 0:
-            for n in range(attention.size(-1)):
-                xc, yc = n * token_gap, -d * layer_gap
-                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-                ctx.fill()
-
         at = attention[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
@@ -71,14 +76,14 @@ def save_attention_image(
                 ctx.stroke()
         y -= layer_gap
 
-    for d in range(1, attention.size(0)):
+    for d in range(0, attention.size(0) + 1):
         for n in range(attention.size(-1)):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
             ctx.fill()
             ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
             ctx.fill()
 
     ctx.set_source_rgb(0.0, 0.0, 0.0)
@@ -93,7 +98,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
@@ -106,7 +111,9 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.move_to(
+            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+        )
         ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()