From 00b2d5ed01fb523fbc4e699f0419329efbee0ea8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 10:06:45 +0200 Subject: [PATCH] Update. --- graph.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ mygpt.py | 27 +++++++++++-- tasks.py | 23 +++++++---- 3 files changed, 154 insertions(+), 11 deletions(-) create mode 100755 graph.py diff --git a/graph.py b/graph.py new file mode 100755 index 0000000..97de6d1 --- /dev/null +++ b/graph.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +import cairo + + +###################################################################### +def save_attention_image( + filename, + tokens, + attention, + surface_width=128, + surface_height=96, + pixel_scale=8, + x=10, + y=10, + token_gap=15, + layer_gap=25, + y_eps=1, + min_att=1e-2, +): + # surface = cairo.PDFSurface( + # filename, surface_width * pixel_scale, surface_height * pixel_scale + # ) + + surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) + + ctx = cairo.Context(surface) + ctx.scale(pixel_scale, pixel_scale) + + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.set_font_size(4.0) + # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) + + u = [] + for n, t in enumerate(tokens): + string = str(t) + ( + x_bearing, + y_bearing, + width_t, + height_t, + x_advance, + y_advance, + ) = ctx.text_extents(string) + u.append((n, string, x, x + width_t / 2, height_t, y_bearing)) + x += x_advance + token_gap + tokens = u + + for d in range(attention.size(0) + 1): + for n, s, x, xc, h, yb in tokens: + # ctx.set_source_rgb(0.0, 0.0, 0.0) + # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t) + # ctx.stroke() + ctx.set_source_rgb(0.0, 0.0, 0.0) + ctx.move_to(x, y) + ctx.show_text(s) + # x += x_advance + 1 + if d < attention.size(0): + for m, _, _, x2c, h2, y2b in tokens: + if attention[d, n, m] >= min_att: + c = 1 - attention[d, n, m] + ctx.set_source_rgb(c, c, c) + ctx.set_line_width(0.5) + ctx.move_to(xc, y + yb + h + y_eps) + ctx.line_to(x2c, y + layer_gap + y2b - y_eps) + ctx.stroke() + y += layer_gap + + x, y, width, height = surface.ink_extents() + pdf_surface = cairo.PDFSurface(filename, width, height) + ctx_pdf = cairo.Context(pdf_surface) + ctx_pdf.set_source_surface(surface, -x, -y) + ctx_pdf.paint() + pdf_surface.finish() + + +###################################################################### + +if __name__ == "__main__": + import mygpt + + vocabulary_size = 3 + x = torch.randint(vocabulary_size, (1, 5)) + + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=4, + dim_keys=2, + dim_hidden=2, + nb_heads=2, + nb_blocks=3, + dropout=0.1, + causal=True, + ) + + model.eval() + model.record_attention() + + y1 = model(mygpt.BracketedSequence(x)).x + + a = model.retrieve_attention() + print(a) + attention = torch.cat([x[:0] for x in a], dim=0) + + tokens = ["bluh", 2, 3, 4, "blih"] + attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1) + + save_attention_image("attention.pdf", tokens, attention) diff --git a/mygpt.py b/mygpt.py index 45b7b59..ac1c55e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -116,7 +116,13 @@ class AddPositionalEncoding(nn.Module): class QKVAttention(nn.Module): def __init__( - self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0 + self, + dim_in, + dim_qk, + dim_v, + nb_heads=1, + causal=False, + attention_dropout=0.0, ): super().__init__() @@ -125,6 +131,7 @@ class QKVAttention(nn.Module): self.causal = causal self.attention_dropout = attention_dropout + self.record_attention = False self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) @@ -162,6 +169,9 @@ class QKVAttention(nn.Module): "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] ) / math.sqrt(self.w_q.size(1)) + if self.record_attention: + self.a = a + if self.causal: if bs_q.first == 0: self.cache_attzero = ( @@ -283,6 +293,18 @@ class MyGPT(nn.Module): t_next = dist.sample() input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + def record_attention(self, v=True): + for m in self.modules(): + if isinstance(m, QKVAttention): + m.record_attention = v + + def retrieve_attention(self): + a = [] + for m in self.modules(): + if isinstance(m, QKVAttention): + a.append(m.a) + return a + ###################################################################### @@ -298,13 +320,12 @@ if __name__ == "__main__": dim_keys=2, dim_hidden=2, nb_heads=2, - nb_blocks=1, + nb_blocks=2, dropout=0.1, causal=True, ) model.eval() - y1 = model(BracketedSequence(x)).x y2 = torch.randn_like(y1) for s in range(x.size(1)): diff --git a/tasks.py b/tasks.py index ca71182..af71b85 100755 --- a/tasks.py +++ b/tasks.py @@ -1111,6 +1111,7 @@ class RPL(Task): self.test_input = self.tensorize(test_sequences) if no_prog: + # Excise the program from every train and test example k = torch.arange(self.train_input.size(1), device=self.train_input.device)[ None, : ] @@ -1185,13 +1186,13 @@ class RPL(Task): ) sum_nb_total, sum_nb_errors = 0, 0 - for x, y in zip(input, result): - seq = [self.id2token[i.item()] for i in y] + for one_input, one_result in zip(input, result): + seq = [self.id2token[i.item()] for i in one_result] nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) sum_nb_total += 1 sum_nb_errors += 0 if nb_errors == 0 else 1 if nb_to_log > 0: - gt_seq = [self.id2token[i.item()] for i in x] + gt_seq = [self.id2token[i.item()] for i in one_input] _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) gt_prog = " ".join([str(x) for x in gt_prog]) prog = " ".join([str(x) for x in prog]) @@ -1232,14 +1233,20 @@ class RPL(Task): ) sum_nb_total, sum_nb_errors = 0, 0 - for x, y, i, j in zip(input, result, last_output_idx, first_prog_idx): - seq = [self.id2token[i.item()] for i in y] + for one_input, one_result, i, j in zip( + input, result, last_output_idx, first_prog_idx + ): + seq = [self.id2token[i.item()] for i in one_result] sum_nb_total += 1 - correct = (x - y).abs().max() == 0 + correct = (one_input - one_result).abs().max() == 0 sum_nb_errors += 0 if correct else 1 if nb_to_log > 0: - result_stack = [self.id2token[i.item()] for i in y[i : j + 1]] - target_stack = [self.id2token[i.item()] for i in x[i : j + 1]] + result_stack = [ + self.id2token[i.item()] for i in one_result[i : j + 1] + ] + target_stack = [ + self.id2token[i.item()] for i in one_input[i : j + 1] + ] comment = "*" if correct else "-" result_stack = " ".join([str(x) for x in result_stack]) target_stack = " ".join([str(x) for x in target_stack]) -- 2.20.1