Make save_attention_image() take a list of per-layer 2D attention matrices.
author François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 20:52:54 +0000 (22:52 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 20:52:54 +0000 (22:52 +0200)
graph.py
tasks.py

diff --git a/graph.py b/graph.py
index a2554d2..a819283 100755
--- a/graph.py
+++ b/graph.py
@@ -18,9 +18,7 @@ def save_attention_image(
     tokens_input,
     tokens_output,
-    # An iterable set of BxHxTxT attention matrices
+    # An iterable of 2D attention matrices (one per layer, queries x keys)
-    attention_arrays,
-    n_sample=0,
-    n_head=0,
+    attention_matrices,
     pixel_scale=8,
     token_gap=15,
     layer_gap=25,
@@ -35,20 +33,19 @@ def save_attention_image(
     k_top=None,
     curved=True,
 ):
-    attention = torch.cat(
-        [x[n_sample : n_sample + 1, n_head] for x in attention_arrays], dim=0
-    )
-
     if k_top is not None:
-        attention = attention * (
-            attention.sort(dim=-1, descending=True).indices < k_top
-        )
+        am = []
+        for m in attention_matrices:
+            am.append(m * (m.sort(dim=-1, descending=True).indices.argsort(-1) < k_top))
+        attention_matrices = am
 
     if min_total_attention is not None:
-        s = attention.sort(dim=-1)
-        m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
-        b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m)
-        attention = attention * b
+        am = []
+        for m in attention_matrices:
+            s = m.sort(dim=-1)
+            k = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
+            am.append(m * k.new(m.size()).scatter_(dim=-1, index=s.indices, src=k))
+        attention_matrices = am
 
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
 
@@ -61,8 +58,9 @@ def save_attention_image(
 
     x, y = 0, 0
 
-    for d in range(attention.size(0)):
-        at = attention[d]
+    ctx.set_line_width(0.25)
+    for d in range(len(attention_matrices)):
+        at = attention_matrices[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
         at = at.flatten()
@@ -74,7 +72,6 @@ def save_attention_image(
             if a > 0 and a >= min_link_attention:
                 c = 1 - a.item()
                 ctx.set_source_rgb(c, c, c)
-                ctx.set_line_width(0.5)
                 ax, ay = j * token_gap, y - y_eps
                 ctx.move_to(ax, ay)
                 dx, dy = i * token_gap, y - layer_gap + y_eps
@@ -87,8 +84,13 @@ def save_attention_image(
                 ctx.stroke()
         y -= layer_gap
 
-    for d in range(0, attention.size(0) + 1):
-        for n in range(attention.size(-1)):
+    for d in range(0, len(attention_matrices) + 1):
+        nb_tokens = (
+            attention_matrices[0].size(-1)
+            if d == 0
+            else attention_matrices[d - 1].size(-2)
+        )
+        for n in range(nb_tokens):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
             ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
@@ -123,7 +125,8 @@ def save_attention_image(
             y_advance,
         ) = ctx.text_extents(s)
         ctx.move_to(
-            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+            k * token_gap - width_t / 2,
+            -token_gap / 5 - len(attention_matrices) * layer_gap,
         )
         ctx.show_text(s)
 
@@ -156,7 +159,7 @@ if __name__ == "__main__":
         dim_keys=2,
         dim_hidden=2,
         nb_heads=2,
-        nb_blocks=3,
+        nb_blocks=5,
         dropout=0.1,
         causal=True,
     )
@@ -166,13 +169,16 @@ if __name__ == "__main__":
 
     y1 = model(mygpt.BracketedSequence(x)).x
 
-    attention = model.retrieve_attention()
+    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
+
+    # attention_matrices = [torch.rand(3, 5), torch.rand(8, 3), torch.rand(5, 8)]
+    # attention_matrices = [a / a.sum(-1, keepdim=True) for a in attention_matrices]
 
     save_attention_image(
         "attention.pdf",
         tokens_input,
         tokens_output,
-        attention,
+        attention_matrices,
         # k_top=2,
         min_total_attention=0.9,
     )
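
For reference, the new calling convention can be exercised on its own, along the lines of the commented-out test in graph.py's __main__ above. A minimal sketch, assuming graph.py (and its pycairo dependency) is importable; the matrix shapes and token labels below are made up for illustration:

    import torch

    from graph import save_attention_image

    # One 2D (queries x keys) attention matrix per layer; successive layers
    # may have different sizes.
    attention_matrices = [torch.rand(3, 5), torch.rand(8, 3), torch.rand(5, 8)]
    attention_matrices = [a / a.sum(-1, keepdim=True) for a in attention_matrices]

    # The bottom row has attention_matrices[0].size(-1) tokens and the top row
    # attention_matrices[-1].size(-2) tokens (five each here).
    tokens_input = ["n/a", "a", "b", "c", "d"]
    tokens_output = ["a", "b", "c", "d", "e"]

    save_attention_image(
        "attention.pdf",
        tokens_input,
        tokens_output,
        attention_matrices,
        min_total_attention=0.9,
    )
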
diff --git a/tasks.py b/tasks.py
index 0eed2aa..0c92af9 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1284,7 +1284,7 @@ class RPL(Task):
         )
 
         if save_attention_image is not None:
-            input = self.test_input[:10]
+            input = self.test_input[:1]
             result = input.clone()
             s = (result == self.t_prog).long()
             ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
@@ -1305,24 +1305,23 @@ class RPL(Task):
                 model.record_attention(True)
                 model(BracketedSequence(result))
                 model.train(t)
-                attention = model.retrieve_attention()
+                ram = model.retrieve_attention()
                 model.record_attention(False)
 
-            n_sample = 0
-            tokens_output = [self.id2token[i.item()] for i in result[n_sample]]
+            tokens_output = [self.id2token[i.item()] for i in result[0]]
             tokens_input = ["n/a"] + tokens_output[:-1]
-            for n_head in range(attention[0].size(1)):
+            for n_head in range(ram[0].size(1)):
                 filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+                attention_matrices = [m[0, n_head] for m in ram]
                 save_attention_image(
                     filename,
                     tokens_input,
                     tokens_output,
-                    attention,
-                    n_sample=n_sample,
-                    n_head=n_head,
+                    attention_matrices,
                     token_gap=12,
-                    layer_gap=40,
-                    # k_top=2,
+                    layer_gap=50,
+                    k_top=10,
+                    # min_total_attention=0.9,
                 )
                 logger(f"wrote {filename}")
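
The per-head loop above generalizes to any model that records attention the way mygpt does. A rough standalone sketch of the same pattern, with retrieve_attention() stood in by random per-layer BxHxTxT tensors and made-up token labels; only the save_attention_image() call reflects the actual interface from this commit:

    import torch

    from graph import save_attention_image

    # Stand-in for model.retrieve_attention(): one BxHxTxT tensor per layer
    # (B=1, H=2, T=4, three layers), rows normalized like real attention.
    ram = [torch.rand(1, 2, 4, 4).softmax(dim=-1) for _ in range(3)]

    tokens_output = ["t0", "t1", "t2", "t3"]  # made-up token labels
    tokens_input = ["n/a"] + tokens_output[:-1]

    for n_head in range(ram[0].size(1)):
        # Slice out sample 0 and the current head: one TxT matrix per layer.
        attention_matrices = [m[0, n_head] for m in ram]
        filename = f"attention_h{n_head}.pdf"
        save_attention_image(
            filename,
            tokens_input,
            tokens_output,
            attention_matrices,
            token_gap=12,
            layer_gap=50,
            k_top=2,  # tasks.py uses k_top=10 on real sequences
        )
        print(f"wrote {filename}")
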