Update: rework attention-map rendering in graph.py (per-sample/per-head maps, separate input and output token rows, optional k_top filtering), record attention after the softmax in mygpt.py, shorten the RPL sequence markers, rename the rpl_* options, and save RPL attention images during evaluation.
author François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 17:00:25 +0000 (19:00 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 22 Jul 2023 17:00:25 +0000 (19:00 +0200)
graph.py
main.py
mygpt.py
rpl.py
tasks.py

diff --git a/graph.py b/graph.py
index 0db7bd0..5bab861 100755 (executable)
--- a/graph.py
+++ b/graph.py
@@ -13,18 +13,27 @@ import cairo
 ######################################################################
 def save_attention_image(
     filename,
-    tokens,
+    tokens_input,
+    tokens_output,
     attention,
+    n_sample=0,
+    n_head=0,
     pixel_scale=8,
     token_gap=10,
     layer_gap=25,
-    y_eps=1.5,
-    padding=0,
-    min_att=1e-2,
+    y_eps=0.5,
+    padding=10,
+    min_att=0,
+    k_top=None,
 ):
-    # surface = cairo.PDFSurface(
-    # filename, surface_width * pixel_scale, surface_height * pixel_scale
-    # )
+    attention = torch.cat(
+        [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0
+    )
+
+    if k_top is not None:
+        # keep only the k_top largest attention values in each row: rank the entries
+        # along the last dimension and zero out everything beyond the k_top-th
+        ranks = attention.sort(dim=-1, descending=True).indices.argsort(dim=-1)
+        attention = attention * (ranks < k_top)
 
     surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
 
@@ -37,9 +46,45 @@ def save_attention_image(
 
     x, y = 0, 0
 
-    u = {}
-    for n, t in enumerate(tokens):
-        string = str(t)
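+    # one row of tokens per layer: draw the attention links, weakest first so the strongest stay visible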
+    for d in range(attention.size(0)):
+        if d > 0:
+            for n in range(attention.size(-1)):
+                xc, yc = n * token_gap, -d * layer_gap
+                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+                ctx.fill()
+
+        at = attention[d]
+        ni = torch.arange(at.size(0))[:, None].expand_as(at)
+        nj = torch.arange(at.size(1))[None, :].expand_as(at)
+        at = at.flatten()
+        o = at.sort().indices
+        at = at[o]
+        ni = ni.flatten()[o]
+        nj = nj.flatten()[o]
+        for i, j, a in zip(ni, nj, at):
+            if a > 0 and a >= min_att:
+                c = 1 - a.item()
+                ctx.set_source_rgb(c, c, c)
+                ctx.set_line_width(0.5)
+                ctx.move_to(j * token_gap, y - y_eps)
+                ctx.line_to(i * token_gap, y - layer_gap + y_eps)
+                ctx.stroke()
+        y -= layer_gap
+
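+    # redraw the intermediate nodes on top of the links: a white halo, then a black dot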
+    for d in range(1, attention.size(0)):
+        for n in range(attention.size(-1)):
+            xc, yc = n * token_gap, -d * layer_gap
+            ctx.set_source_rgb(1.0, 1.0, 1.0)
+            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.fill()
+            ctx.set_source_rgb(0.0, 0.0, 0.0)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.fill()
+
+    ctx.set_source_rgb(0.0, 0.0, 0.0)
+
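+    # write the input tokens along the bottom row and the output tokens along the top row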
+    for k, t in enumerate(tokens_input):
+        s = str(t)
         (
             x_bearing,
             y_bearing,
@@ -47,30 +92,22 @@ def save_attention_image(
             height_t,
             x_advance,
             y_advance,
-        ) = ctx.text_extents(string)
-        u[n]=(string, x, x + width_t / 2, height_t, y_bearing)
-        x += x_advance + token_gap
-    tokens = u
-
-    for d in range(attention.size(0) + 1):
-        for n, (s, x, xc, h, yb) in tokens.items():
-            if d < attention.size(0):
-                for m, (_, _, x2c, h2, y2b) in tokens.items():
-                    if attention[d, n, m] >= min_att:
-                        c = 1 - attention[d, n, m]
-                        ctx.set_source_rgb(c, c, c)
-                        ctx.set_line_width(0.5)
-                        ctx.move_to(xc, y + yb + h + y_eps)
-                        ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
-                        ctx.stroke()
-            # ctx.set_source_rgb(0.0, 0.0, 0.0)
-            # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
-            # ctx.stroke()
-            ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.move_to(x, y)
-            ctx.show_text(s)
-            # x += x_advance + 1
-        y += layer_gap
+        ) = ctx.text_extents(s)
+        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.show_text(s)
+
+    for k, t in enumerate(tokens_output):
+        s = str(t)
+        (
+            x_bearing,
+            y_bearing,
+            width_t,
+            height_t,
+            x_advance,
+            y_advance,
+        ) = ctx.text_extents(s)
+        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()
     x -= padding
@@ -89,8 +126,11 @@ def save_attention_image(
 if __name__ == "__main__":
     import mygpt
 
+    tokens_output = ["bluh", 2, 3, 4, "blih"]
+    tokens_input = ["n/a"] + tokens_output[:-1]
+
     vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, 5))
+    x = torch.randint(vocabulary_size, (1, len(tokens_input)))
 
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
@@ -108,11 +148,6 @@ if __name__ == "__main__":
 
     y1 = model(mygpt.BracketedSequence(x)).x
 
-    a = model.retrieve_attention()
-    print(a)
-    attention = torch.cat([x[:0] for x in a], dim=0)
-
-    tokens = ["bluh", 2, 3, 4, "blih"]
-    attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
+    attention = model.retrieve_attention()
 
-    save_attention_image("attention.pdf", tokens, attention, padding=3)
+    save_attention_image("attention.pdf", tokens_input, tokens_output, attention)
diff --git a/main.py b/main.py
index 5506bbd..1b0d39a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -81,15 +81,15 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 ##############################
 # rpl options
 
-parser.add_argument("--rpl-nb_starting_values", type=int, default=5)
+parser.add_argument("--rpl_nb_starting_values", type=int, default=5)
 
-parser.add_argument("--rpl-max_input", type=int, default=9)
+parser.add_argument("--rpl_max_input", type=int, default=9)
 
-parser.add_argument("--rpl-prog_len", type=int, default=10)
+parser.add_argument("--rpl_prog_len", type=int, default=10)
 
-parser.add_argument("--rpl-nb_runs", type=int, default=8)
+parser.add_argument("--rpl_nb_runs", type=int, default=8)
 
-parser.add_argument("--rpl-no-prog", action="store_true", default=False)
+parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 ##############################
 # sandbox options
@@ -518,12 +518,12 @@ else:
 
 if args.task == "expr" and args.expr_input_file is not None:
     task.produce_results(
-        nb_epochs_finished,
-        model,
-        args.result_dir,
-        log_string,
-        args.deterministic_synthesis,
-        args.expr_input_file,
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
+        input_file=args.expr_input_file,
     )
 
     exit(0)
@@ -599,11 +599,11 @@ nb_samples_seen = 0
 
 if nb_epochs_finished >= nb_epochs:
     task.produce_results(
-        nb_epochs_finished,
-        model,
-        args.result_dir,
-        log_string,
-        args.deterministic_synthesis,
+        n_epoch=nb_epochs_finished,
+        model=model,
+        result_dir=args.result_dir,
+        logger=log_string,
+        deterministic_synthesis=args.deterministic_synthesis,
     )
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
@@ -657,7 +657,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         task.produce_results(
-            n_epoch, model, args.result_dir, log_string, args.deterministic_synthesis
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            logger=log_string,
+            deterministic_synthesis=args.deterministic_synthesis,
         )
 
     checkpoint = {
diff --git a/mygpt.py b/mygpt.py
index ac1c55e..0400b48 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -169,9 +169,6 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.record_attention:
-            self.a = a
-
         if self.causal:
             if bs_q.first == 0:
                 self.cache_attzero = (
@@ -186,6 +183,10 @@ class QKVAttention(nn.Module):
             )
 
         a = a.softmax(dim=3)
+
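+        # store the attention after the softmax so retrieve_attention() returns proper distributions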
+        if self.record_attention:
+            self.a = a
+
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
diff --git a/rpl.py b/rpl.py
index b51edef..b848afa 100755 (executable)
--- a/rpl.py
+++ b/rpl.py
@@ -75,9 +75,9 @@ def generate(
             result_stack = rpl_exec(prog, stack)
             if len(result_stack) == 0:
                 no_empty_stack = False
-            result = result + ["<input>"] + stack + ["<output>"] + result_stack
+            result = result + ["<in>"] + stack + ["<out>"] + result_stack
 
-        result = result + ["<prog>"] + prog
+        result = result + ["<prg>"] + prog
         result = result + ["<end>"]
 
         if no_empty_stack and (
@@ -103,11 +103,11 @@ def next_marker(seq, tokens, start=0):
 def decompose(seq):
     io = []
     k = 0
-    while seq[k] == "<input>":
-        o = next_marker(seq, ["<output>"], start=k + 1)
+    while seq[k] == "<in>":
+        o = next_marker(seq, ["<out>"], start=k + 1)
         if o is None:
             raise ValueError("Missing output markers (should be correct in the prompt)")
-        e = next_marker(seq, ["<input>", "<prog>"], start=o)
+        e = next_marker(seq, ["<in>", "<prg>"], start=o)
         if e is None:
             raise ValueError(
                 "Missing input/output markers (should be correct in the prompt)"
@@ -123,14 +123,14 @@ def decompose(seq):
 
         k = e
 
-    if seq[k] == "<prog>":
+    if seq[k] == "<prg>":
         e = next_marker(seq, ["<end>"], start=k)
         if e is None:
             prog = []
         else:
             prog = seq[k + 1 : e]
     else:
-        raise ValueError("Missing <prog> (it should be in the prompt)")
+        raise ValueError("Missing <prg> (it should be in the prompt)")
 
     return prog, io
 
diff --git a/tasks.py b/tasks.py
index af71b85..0eed2aa 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -12,6 +12,13 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+from mygpt import BracketedSequence
+
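+# graph.py depends on pycairo; if it cannot be imported, attention images are simply skipped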
+try:
+    from graph import save_attention_image
+except ImportError:
+    save_attention_image = None
+
 ######################################################################
 
 
@@ -1102,9 +1109,9 @@ class RPL(Task):
         self.id2token = dict([(n, c) for c, n in self.token2id.items()])
 
         self.t_nul = self.token2id["<nul>"]
-        self.t_input = self.token2id["<input>"]
-        self.t_output = self.token2id["<output>"]
-        self.t_prog = self.token2id["<prog>"]
+        self.t_input = self.token2id["<in>"]
+        self.t_output = self.token2id["<out>"]
+        self.t_prog = self.token2id["<prg>"]
         self.t_end = self.token2id["<end>"]
 
         self.train_input = self.tensorize(train_sequences)
@@ -1276,6 +1283,49 @@ class RPL(Task):
             f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
         )
 
+        if save_attention_image is not None:
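+            # render the attention maps of a few test sequences, one image per head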
+            input = self.test_input[:10]
+            result = input.clone()
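+            # blank out everything after the <prg> marker with <nul> so the model regenerates the program part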
+            s = (result == self.t_prog).long()
+            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+            result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                device=self.device,
+            )
+
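+            # run the model once more on the completed sequences with attention recording turned on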
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+                model.record_attention(True)
+                model(BracketedSequence(result))
+                model.train(t)
+                attention = model.retrieve_attention()
+                model.record_attention(False)
+
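+            # token labels: the input at position t is the output at position t-1, with "n/a" at t=0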
+            n_sample = 0
+            tokens_output = [self.id2token[i.item()] for i in result[n_sample]]
+            tokens_input = ["n/a"] + tokens_output[:-1]
+            for n_head in range(attention[0].size(1)):
+                filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+                save_attention_image(
+                    filename,
+                    tokens_input,
+                    tokens_output,
+                    attention,
+                    n_sample=n_sample,
+                    n_head=n_head,
+                    token_gap=12,
+                    layer_gap=40,
+                    # k_top=2,
+                )
+                logger(f"wrote {filename}")
+
 
 ######################################################################