Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 17:05:38 +0000 (19:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 17:05:38 +0000 (19:05 +0200)
graph.py

index 5bab861..5195cc9 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -47,12 +47,6 @@ def save_attention_image(
     x, y = 0, 0
 
     for d in range(attention.size(0)):
-        if d > 0:
-            for n in range(attention.size(-1)):
-                xc, yc = n * token_gap, -d * layer_gap
-                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-                ctx.fill()
-
         at = attention[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
@@ -71,14 +65,14 @@ def save_attention_image(
                 ctx.stroke()
         y -= layer_gap
 
-    for d in range(1, attention.size(0)):
+    for d in range(0, attention.size(0) + 1):
         for n in range(attention.size(-1)):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
             ctx.fill()
             ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
             ctx.fill()
 
     ctx.set_source_rgb(0.0, 0.0, 0.0)
@@ -93,7 +87,7 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
@@ -106,7 +100,9 @@ def save_attention_image(
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.move_to(
+            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+        )
         ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()