--- /dev/null
+#!/usr/bin/env python
+
+import torch
+
+import cairo
+
+
+######################################################################
+def save_attention_image(
+ filename,
+ tokens,
+ attention,
+ surface_width=128,
+ surface_height=96,
+ pixel_scale=8,
+ x=10,
+ y=10,
+ token_gap=15,
+ layer_gap=25,
+ y_eps=1,
+ min_att=1e-2,
+):
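+    # Render the tokens as nb_layers + 1 rows of text and, between two
+    # consecutive rows, draw segments whose darkness encodes the attention.
+    # attention is expected to be of shape (nb_layers, nb_tokens, nb_tokens),
+    # with values in [0, 1]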
+    # Draw on a recording surface first, so that the result can be cropped
+    # to its ink extents before being replayed into the final PDF
+    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
+
+ ctx = cairo.Context(surface)
+ ctx.scale(pixel_scale, pixel_scale)
+
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.set_font_size(4.0)
+
+    # Compute the horizontal position and text extents of every token
+    u = []
+ for n, t in enumerate(tokens):
+ string = str(t)
+ (
+ x_bearing,
+ y_bearing,
+ width_t,
+ height_t,
+ x_advance,
+ y_advance,
+ ) = ctx.text_extents(string)
+ u.append((n, string, x, x + width_t / 2, height_t, y_bearing))
+ x += x_advance + token_gap
+ tokens = u
+
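+    # Draw the rows of tokens, one per attention matrix plus a final one,
+    # with the attention segments in between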
+ for d in range(attention.size(0) + 1):
+ for n, s, x, xc, h, yb in tokens:
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.move_to(x, y)
+ ctx.show_text(s)
+ if d < attention.size(0):
+ for m, _, _, x2c, h2, y2b in tokens:
+                    if attention[d, n, m].item() >= min_att:
+                        # The darker the segment, the larger the attention
+                        c = 1 - attention[d, n, m].item()
+                        ctx.set_source_rgb(c, c, c)
+ ctx.set_line_width(0.5)
+ ctx.move_to(xc, y + yb + h + y_eps)
+ ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
+ ctx.stroke()
+ y += layer_gap
+
+    # Crop to the ink extents and replay the recording into the PDF surface
+    x, y, width, height = surface.ink_extents()
+ pdf_surface = cairo.PDFSurface(filename, width, height)
+ ctx_pdf = cairo.Context(pdf_surface)
+ ctx_pdf.set_source_surface(surface, -x, -y)
+ ctx_pdf.paint()
+ pdf_surface.finish()
+
+
+######################################################################
+
+if __name__ == "__main__":
+ import mygpt
+
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, 5))
+
+ model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
+ nb_heads=2,
+ nb_blocks=3,
+ dropout=0.1,
+ causal=True,
+ )
+
+ model.eval()
+ model.record_attention()
+
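+    # Forward pass in eval mode, during which every attention module stores
+    # its attention map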
+ y1 = model(mygpt.BracketedSequence(x)).x
+
+    a = model.retrieve_attention()
+    # One tensor of shape (N, H, T, T) per block; drop the batch dimension
+    # and stack the heads of all the blocks along dim 0
+    attention = torch.cat([x[0] for x in a], dim=0)
+
+ tokens = ["bluh", 2, 3, 4, "blih"]
+ attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
+
+ save_attention_image("attention.pdf", tokens, attention)
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
):
super().__init__()
self.causal = causal
self.attention_dropout = attention_dropout
+ self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
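+        # Keep the attention map around for later retrieval; note that it is
+        # recorded here, before the causal mask is applied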
+ if self.record_attention:
+ self.a = a
+
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
t_next = dist.sample()
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+ def record_attention(self, v=True):
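+        # Set the record_attention flag of every attention module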
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
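+        # Return the attention maps stored by the attention modules, one per
+        # block, in model order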
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+
######################################################################
dim_keys=2,
dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
causal=True,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
self.test_input = self.tensorize(test_sequences)
if no_prog:
+ # Excise the program from every train and test example
k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
None, :
]
)
sum_nb_total, sum_nb_errors = 0, 0
- for x, y in zip(input, result):
- seq = [self.id2token[i.item()] for i in y]
+ for one_input, one_result in zip(input, result):
+ seq = [self.id2token[i.item()] for i in one_result]
nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
sum_nb_total += 1
sum_nb_errors += 0 if nb_errors == 0 else 1
if nb_to_log > 0:
- gt_seq = [self.id2token[i.item()] for i in x]
+ gt_seq = [self.id2token[i.item()] for i in one_input]
_, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
gt_prog = " ".join([str(x) for x in gt_prog])
prog = " ".join([str(x) for x in prog])
)
sum_nb_total, sum_nb_errors = 0, 0
- for x, y, i, j in zip(input, result, last_output_idx, first_prog_idx):
- seq = [self.id2token[i.item()] for i in y]
+        for one_input, one_result, i, j in zip(
+            input, result, last_output_idx, first_prog_idx
+        ):
+            seq = [self.id2token[u.item()] for u in one_result]
sum_nb_total += 1
- correct = (x - y).abs().max() == 0
+ correct = (one_input - one_result).abs().max() == 0
sum_nb_errors += 0 if correct else 1
if nb_to_log > 0:
- result_stack = [self.id2token[i.item()] for i in y[i : j + 1]]
- target_stack = [self.id2token[i.item()] for i in x[i : j + 1]]
+                result_stack = [
+                    self.id2token[u.item()] for u in one_result[i : j + 1]
+                ]
+                target_stack = [
+                    self.id2token[u.item()] for u in one_input[i : j + 1]
+                ]
comment = "*" if correct else "-"
result_stack = " ".join([str(x) for x in result_stack])
target_stack = " ".join([str(x) for x in target_stack])