######################################################################
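+ # Saves a PDF showing one head's attention for one sample: input tokens at
+ # the bottom, output tokens at the top, a row of dots per layer, and grey
+ # edges whose darkness encodes the attention weight.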
def save_attention_image(
filename,
- tokens,
+ tokens_input,
+ tokens_output,
attention,
+ n_sample=0,
+ n_head=0,
pixel_scale=8,
token_gap=10,
layer_gap=25,
- y_eps=1.5,
- padding=0,
- min_att=1e-2,
+ y_eps=0.5,
+ padding=10,
+ min_att=0,
+ k_top=None,
):
- # surface = cairo.PDFSurface(
- # filename, surface_width * pixel_scale, surface_height * pixel_scale
- # )
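+ # keep one (sample, head) slice per layer and stack them into a tensor
+ # of shape (nb_layers, T, T)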
+ attention = torch.cat(
+ [x[n_sample : n_sample + 1, n_head] for x in attention], dim=0
+ )
+
+ if k_top is not None:
+ # keep only the k_top largest attention weights in each row: rank the
+ # entries and zero out those of rank k_top or beyond
+ attention = attention * (
+ attention.sort(dim=-1, descending=True).indices.argsort(dim=-1) < k_top
+ )
surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)

# a drawing context must exist before the ctx.* calls below; pixel_scale
# is applied here, and the font size is an assumed sensible default
ctx = cairo.Context(surface)
ctx.scale(pixel_scale, pixel_scale)
ctx.set_source_rgb(0.0, 0.0, 0.0)
ctx.set_font_size(4.0)

x, y = 0, 0
- u = {}
- for n, t in enumerate(tokens):
- string = str(t)
+ for d in range(attention.size(0)):
+ if d > 0:
+ for n in range(attention.size(-1)):
+ xc, yc = n * token_gap, -d * layer_gap
+ ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+ ctx.fill()
+
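+ # draw the edges of layer d in order of increasing weight, so the
+ # darkest (strongest) ones are painted last and end up on top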
+ at = attention[d]
+ ni = torch.arange(at.size(0))[:, None].expand_as(at)
+ nj = torch.arange(at.size(1))[None, :].expand_as(at)
+ at = at.flatten()
+ o = at.sort().indices
+ at = at[o]
+ ni = ni.flatten()[o]
+ nj = nj.flatten()[o]
+ for i, j, a in zip(ni, nj, at):
+ if a > 0 and a >= min_att:
+ c = 1 - a.item()
+ ctx.set_source_rgb(c, c, c)
+ ctx.set_line_width(0.5)
+ ctx.move_to(j * token_gap, y - y_eps)
+ ctx.line_to(i * token_gap, y - layer_gap + y_eps)
+ ctx.stroke()
+ y -= layer_gap
+
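+ # redraw the intermediate-layer dots with a white halo so they remain
+ # visible over the edges drawn above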
+ for d in range(1, attention.size(0)):
+ for n in range(attention.size(-1)):
+ xc, yc = n * token_gap, -d * layer_gap
+ ctx.set_source_rgb(1.0, 1.0, 1.0)
+ ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+ ctx.fill()
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+ ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+ ctx.fill()
+
+ ctx.set_source_rgb(0.0, 0.0, 0.0)
+
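+ # input tokens along the bottom row, each centered on its column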
+ for k, t in enumerate(tokens_input):
+ s = str(t)
(
x_bearing,
y_bearing,
width_t,
height_t,
x_advance,
y_advance,
- ) = ctx.text_extents(string)
- u[n]=(string, x, x + width_t / 2, height_t, y_bearing)
- x += x_advance + token_gap
- tokens = u
-
- for d in range(attention.size(0) + 1):
- for n, (s, x, xc, h, yb) in tokens.items():
- if d < attention.size(0):
- for m, (_, _, x2c, h2, y2b) in tokens.items():
- if attention[d, n, m] >= min_att:
- c = 1 - attention[d, n, m]
- ctx.set_source_rgb(c, c, c)
- ctx.set_line_width(0.5)
- ctx.move_to(xc, y + yb + h + y_eps)
- ctx.line_to(x2c, y + layer_gap + y2b - y_eps)
- ctx.stroke()
- # ctx.set_source_rgb(0.0, 0.0, 0.0)
- # ctx.rectangle(x+x_bearing,y+y_bearing,width_t,height_t)
- # ctx.stroke()
- ctx.set_source_rgb(0.0, 0.0, 0.0)
- ctx.move_to(x, y)
- ctx.show_text(s)
- # x += x_advance + 1
- y += layer_gap
+ ) = ctx.text_extents(s)
+ ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+ ctx.show_text(s)
+
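+ # output tokens above the top attention layer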
+ for k, t in enumerate(tokens_output):
+ s = str(t)
+ (
+ x_bearing,
+ y_bearing,
+ width_t,
+ height_t,
+ x_advance,
+ y_advance,
+ ) = ctx.text_extents(s)
+ ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+ ctx.show_text(s)
x, y, width, height = surface.ink_extents()
x -= padding
if __name__ == "__main__":
import mygpt
+ tokens_output = ["bluh", 2, 3, 4, "blih"]
+ tokens_input = ["n/a"] + tokens_output[:-1]
+
vocabulary_size = 3
- x = torch.randint(vocabulary_size, (1, 5))
+ x = torch.randint(vocabulary_size, (1, len(tokens_input)))
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
+ model.eval()
+ model.record_attention(True)  # attention is stored only while recording is on
+
y1 = model(mygpt.BracketedSequence(x)).x
- a = model.retrieve_attention()
- print(a)
- attention = torch.cat([x[:0] for x in a], dim=0)
-
- tokens = ["bluh", 2, 3, 4, "blih"]
- attention = torch.randn(3, len(tokens), len(tokens)).softmax(dim=-1)
+ attention = model.retrieve_attention()
- save_attention_image("attention.pdf", tokens, attention, padding=3)
+ save_attention_image("attention.pdf", tokens_input, tokens_output, attention)
##############################
# rpl options
-parser.add_argument("--rpl-nb_starting_values", type=int, default=5)
+parser.add_argument("--rpl_nb_starting_values", type=int, default=5)
-parser.add_argument("--rpl-max_input", type=int, default=9)
+parser.add_argument("--rpl_max_input", type=int, default=9)
-parser.add_argument("--rpl-prog_len", type=int, default=10)
+parser.add_argument("--rpl_prog_len", type=int, default=10)
-parser.add_argument("--rpl-nb_runs", type=int, default=8)
+parser.add_argument("--rpl_nb_runs", type=int, default=8)
-parser.add_argument("--rpl-no-prog", action="store_true", default=False)
+parser.add_argument("--rpl_no_prog", action="store_true", default=False)
##############################
# sandbox options
if args.task == "expr" and args.expr_input_file is not None:
task.produce_results(
- nb_epochs_finished,
- model,
- args.result_dir,
- log_string,
- args.deterministic_synthesis,
- args.expr_input_file,
+ n_epoch=nb_epochs_finished,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
+ input_file=args.expr_input_file,
)
exit(0)
if nb_epochs_finished >= nb_epochs:
task.produce_results(
- nb_epochs_finished,
- model,
- args.result_dir,
- log_string,
- args.deterministic_synthesis,
+ n_epoch=nb_epochs_finished,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
)
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
task.produce_results(
- n_epoch, model, args.result_dir, log_string, args.deterministic_synthesis
+ n_epoch=n_epoch,
+ model=model,
+ result_dir=args.result_dir,
+ logger=log_string,
+ deterministic_synthesis=args.deterministic_synthesis,
)
checkpoint = {
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
- if self.record_attention:
- self.a = a
-
if self.causal:
if bs_q.first == 0:
self.cache_attzero = (
)
a = a.softmax(dim=3)
+
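+ # record the post-softmax attention, so that retrieve_attention()
+ # returns proper distributions (rows summing to one)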
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
result_stack = rpl_exec(prog, stack)
if len(result_stack) == 0:
no_empty_stack = False
- result = result + ["<input>"] + stack + ["<output>"] + result_stack
+ result = result + ["<in>"] + stack + ["<out>"] + result_stack
- result = result + ["<prog>"] + prog
+ result = result + ["<prg>"] + prog
result = result + ["<end>"]
if no_empty_stack and (
def decompose(seq):
io = []
k = 0
- while seq[k] == "<input>":
- o = next_marker(seq, ["<output>"], start=k + 1)
+ while seq[k] == "<in>":
+ o = next_marker(seq, ["<out>"], start=k + 1)
if o is None:
raise ValueError("Missing output markers (should be correct in the prompt)")
- e = next_marker(seq, ["<input>", "<prog>"], start=o)
+ e = next_marker(seq, ["<in>", "<prg>"], start=o)
if e is None:
raise ValueError(
"Missing input/output markers (should be correct in the prompt)"
k = e
- if seq[k] == "<prog>":
+ if seq[k] == "<prg>":
e = next_marker(seq, ["<end>"], start=k)
if e is None:
prog = []
else:
prog = seq[k + 1 : e]
else:
- raise ValueError("Missing <prog> (it should be in the prompt)")
+ raise ValueError("Missing <prg> (it should be in the prompt)")
return prog, io
from torch import nn
from torch.nn import functional as F
+from mygpt import BracketedSequence
+
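+ # the attention plots are optional: skip them if graph (which requires
+ # pycairo) cannot be imported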
+try:
+ from graph import save_attention_image
+except ImportError:
+ save_attention_image = None
+
######################################################################
self.id2token = dict([(n, c) for c, n in self.token2id.items()])
self.t_nul = self.token2id["<nul>"]
- self.t_input = self.token2id["<input>"]
- self.t_output = self.token2id["<output>"]
- self.t_prog = self.token2id["<prog>"]
+ self.t_input = self.token2id["<in>"]
+ self.t_output = self.token2id["<out>"]
+ self.t_prog = self.token2id["<prg>"]
self.t_end = self.token2id["<end>"]
self.train_input = self.tensorize(train_sequences)
f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
)
+ if save_attention_image is not None:
+ input = self.test_input[:10]
+ result = input.clone()
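+ # blank out, with <nul>, everything strictly after the <prg> marker, so
+ # that the model regenerates the program part autoregressively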
+ s = (result == self.t_prog).long()
+ ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
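+ # replay the completed sequences once, in eval mode and with attention
+ # recording enabled, to collect the per-layer attention maps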
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ model.record_attention(True)
+ model(BracketedSequence(result))
+ model.train(t)
+ attention = model.retrieve_attention()
+ model.record_attention(False)
+
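+ # plot the first sample; the input token at position t is the output
+ # token at position t-1, hence the shift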
+ n_sample = 0
+ tokens_output = [self.id2token[i.item()] for i in result[n_sample]]
+ tokens_input = ["n/a"] + tokens_output[:-1]
+ for n_head in range(attention[0].size(1)):
+ filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+ save_attention_image(
+ filename,
+ tokens_input,
+ tokens_output,
+ attention,
+ n_sample=n_sample,
+ n_head=n_head,
+ token_gap=12,
+ layer_gap=40,
+ # k_top=2,
+ )
+ logger(f"wrote {filename}")
+
######################################################################