From ef3bef5253ff719953dfffff28d4122c19acdd77 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 19:00:25 +0200 Subject: [PATCH] Update. --- graph.py | 119 +++++++++++++++++++++++++++++++++++-------------------- main.py | 38 ++++++++++-------- mygpt.py | 7 ++-- rpl.py | 14 +++---- tasks.py | 56 ++++++++++++++++++++++++-- 5 files changed, 162 insertions(+), 72 deletions(-) diff --git a/graph.py b/graph.py index 0db7bd0..5bab861 100755 --- a/graph.py +++ b/graph.py @@ -13,18 +13,27 @@ import cairo ###################################################################### def save_attention_image( filename, - tokens, + tokens_input, + tokens_output, attention, + n_sample=0, + n_head=0, pixel_scale=8, token_gap=10, layer_gap=25, - y_eps=1.5, - padding=0, - min_att=1e-2, + y_eps=0.5, + padding=10, + min_att=0, + k_top=None, ): - # surface = cairo.PDFSurface( - # filename, surface_width * pixel_scale, surface_height * pixel_scale - # ) + attention = torch.cat( + [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0 + ) + + if k_top is not None: + attention = attention * ( + attention.sort(dim=-1, descending=True).indices < k_top + ) surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) @@ -37,9 +46,45 @@ def save_attention_image( x, y = 0, 0 - u = {} - for n, t in enumerate(tokens): - string = str(t) + for d in range(attention.size(0)): + if d > 0: + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + at = attention[d] + ni = torch.arange(at.size(0))[:, None].expand_as(at) + nj = torch.arange(at.size(1))[None, :].expand_as(at) + at = at.flatten() + o = at.sort().indices + at = at[o] + ni = ni.flatten()[o] + nj = nj.flatten()[o] + for i, j, a in zip(ni, nj, at): + if a > 0 and a >= min_att: + c = 1 - a.item() + ctx.set_source_rgb(c, c, c) + ctx.set_line_width(0.5) + ctx.move_to(j * token_gap, y - y_eps) + ctx.line_to(i * token_gap, y - layer_gap + y_eps) + ctx.stroke() + y -= layer_gap + + for d in range(1, attention.size(0)): + for n in range(attention.size(-1)): + xc, yc = n * token_gap, -d * layer_gap + ctx.set_source_rgb(1.0, 1.0, 1.0) + ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi) + ctx.fill() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.fill() + + ctx.set_source_rgb(0.0, 0.0, 0.0) + + for k, t in enumerate(tokens_input): + s = str(t) ( x_bearing, y_bearing, @@ -47,30 +92,22 @@ def save_attention_image( height_t, x_advance, y_advance, - ) = ctx.text_extents(string) - u[n]=(string, x, x + width_t / 2, height_t, y_bearing) - x += x_advance + token_gap - tokens = u - - for d in range(attention.size(0) + 1): - for n, (s, x, xc, h, yb) in tokens.items(): - if d < attention.size(0): - for m, (_, _, x2c, h2, y2b) in tokens.items(): - if attention[d, n, m] >= min_att: - c = 1 - attention[d, n, m] - ctx.set_source_rgb(c, c, c) - ctx.set_line_width(0.5) - ctx.move_to(xc, y + yb + h + y_eps) - ctx.line_to(x2c, y + layer_gap + y2b - y_eps) - ctx.stroke() - # ctx.set_source_rgb(0.0, 0.0, 0.0) - # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) - # ctx.stroke() - ctx.set_source_rgb(0.0, 0.0, 0.0) - ctx.move_to(x, y) - ctx.show_text(s) - # x += x_advance + 1 - y += layer_gap + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -y_bearing) + ctx.show_text(s) + + for k, t in enumerate(tokens_output): + s = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(s) + ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap) + ctx.show_text(s) x, y, width, height = surface.ink_extents() x -= padding @@ -89,8 +126,11 @@ def save_attention_image( if __name__ == "__main__": import mygpt + tokens_output = ["bluh", 2, 3, 4, "blih"] + tokens_input = ["n/a"] + tokens_output[:-1] + vocabulary_size = 3 - x = torch.randint(vocabulary_size, (1, 5)) + x = torch.randint(vocabulary_size, (1, len(tokens_input))) model = mygpt.MyGPT( vocabulary_size=vocabulary_size, @@ -108,11 +148,6 @@ if __name__ == "__main__": y1 = model(mygpt.BracketedSequence(x)).x - a = model.retrieve_attention() - print(a) - attention = torch.cat([x[:0] for x in a], dim=0) - - tokens = ["bluh", 2, 3, 4, "blih"] - attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1) + attention = model.retrieve_attention() - save_attention_image("attention.pdf", tokens, attention, padding=3) + save_attention_image("attention.pdf", tokens_input, tokens_output, attention) diff --git a/main.py b/main.py index 5506bbd..1b0d39a 100755 --- a/main.py +++ b/main.py @@ -81,15 +81,15 @@ parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # rpl options -parser.add_argument("--rpl-nb_starting_values", type=int, default=5) +parser.add_argument("--rpl_nb_starting_values", type=int, default=5) -parser.add_argument("--rpl-max_input", type=int, default=9) +parser.add_argument("--rpl_max_input", type=int, default=9) -parser.add_argument("--rpl-prog_len", type=int, default=10) +parser.add_argument("--rpl_prog_len", type=int, default=10) -parser.add_argument("--rpl-nb_runs", type=int, default=8) +parser.add_argument("--rpl_nb_runs", type=int, default=8) -parser.add_argument("--rpl-no-prog", action="store_true", default=False) +parser.add_argument("--rpl_no_prog", action="store_true", default=False) ############################## # sandbox options @@ -518,12 +518,12 @@ else: if args.task == "expr" and args.expr_input_file is not None: task.produce_results( - nb_epochs_finished, - model, - args.result_dir, - log_string, - args.deterministic_synthesis, - args.expr_input_file, + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, + input_file=args.expr_input_file, ) exit(0) @@ -599,11 +599,11 @@ nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: task.produce_results( - nb_epochs_finished, - model, - args.result_dir, - log_string, - args.deterministic_synthesis, + n_epoch=nb_epochs_finished, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, ) for n_epoch in range(nb_epochs_finished, nb_epochs): @@ -657,7 +657,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): ) task.produce_results( - n_epoch, model, args.result_dir, log_string, args.deterministic_synthesis + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + logger=log_string, + deterministic_synthesis=args.deterministic_synthesis, ) checkpoint = { diff --git a/mygpt.py b/mygpt.py index ac1c55e..0400b48 100755 --- a/mygpt.py +++ b/mygpt.py @@ -169,9 +169,6 @@ class QKVAttention(nn.Module): "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] ) / math.sqrt(self.w_q.size(1)) - if self.record_attention: - self.a = a - if self.causal: if bs_q.first == 0: self.cache_attzero = ( @@ -186,6 +183,10 @@ class QKVAttention(nn.Module): ) a = a.softmax(dim=3) + + if self.record_attention: + self.a = a + a = F.dropout(a, self.attention_dropout, self.training) y = torch.einsum( diff --git a/rpl.py b/rpl.py index b51edef..b848afa 100755 --- a/rpl.py +++ b/rpl.py @@ -75,9 +75,9 @@ def generate( result_stack = rpl_exec(prog, stack) if len(result_stack) == 0: no_empty_stack = False - result = result + [""] + stack + [""] + result_stack + result = result + [""] + stack + [""] + result_stack - result = result + [""] + prog + result = result + [""] + prog result = result + [""] if no_empty_stack and ( @@ -103,11 +103,11 @@ def next_marker(seq, tokens, start=0): def decompose(seq): io = [] k = 0 - while seq[k] == "": - o = next_marker(seq, [""], start=k + 1) + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) if o is None: raise ValueError("Missing output markers (should be correct in the prompt)") - e = next_marker(seq, ["", ""], start=o) + e = next_marker(seq, ["", ""], start=o) if e is None: raise ValueError( "Missing input/output markers (should be correct in the prompt)" @@ -123,14 +123,14 @@ def decompose(seq): k = e - if seq[k] == "": + if seq[k] == "": e = next_marker(seq, [""], start=k) if e is None: prog = [] else: prog = seq[k + 1 : e] else: - raise ValueError("Missing (it should be in the prompt)") + raise ValueError("Missing (it should be in the prompt)") return prog, io diff --git a/tasks.py b/tasks.py index af71b85..0eed2aa 100755 --- a/tasks.py +++ b/tasks.py @@ -12,6 +12,13 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +from mygpt import BracketedSequence + +try: + from graph import save_attention_image +except ImportError: + save_attention_image = None + ###################################################################### @@ -1102,9 +1109,9 @@ class RPL(Task): self.id2token = dict([(n, c) for c, n in self.token2id.items()]) self.t_nul = self.token2id[""] - self.t_input = self.token2id[""] - self.t_output = self.token2id[""] - self.t_prog = self.token2id[""] + self.t_input = self.token2id[""] + self.t_output = self.token2id[""] + self.t_prog = self.token2id[""] self.t_end = self.token2id[""] self.train_input = self.tensorize(train_sequences) @@ -1276,6 +1283,49 @@ class RPL(Task): f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" ) + if save_attention_image is not None: + input = self.test_input[:10] + result = input.clone() + s = (result == self.t_prog).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + with torch.autograd.no_grad(): + t = model.training + model.eval() + model.record_attention(True) + model(BracketedSequence(result)) + model.train(t) + attention = model.retrieve_attention() + model.record_attention(False) + + n_sample = 0 + tokens_output = [self.id2token[i.item()] for i in result[n_sample]] + tokens_input = ["n/a"] + tokens_output[:-1] + for n_head in range(attention[0].size(1)): + filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf" + save_attention_image( + filename, + tokens_input, + tokens_output, + attention, + n_sample=n_sample, + n_head=n_head, + token_gap=12, + layer_gap=40, + # k_top=2, + ) + logger(f"wrote {filename}") + ###################################################################### -- 2.39.5