from torch import nn
from torch.nn import functional as F
+from mygpt import BracketedSequence
+
+try:
+ from graph import save_attention_image
+except ImportError:
+ save_attention_image = None
+
######################################################################
self.id2token = dict([(n, c) for c, n in self.token2id.items()])
self.t_nul = self.token2id["<nul>"]
- self.t_input = self.token2id["<input>"]
- self.t_output = self.token2id["<output>"]
- self.t_prog = self.token2id["<prog>"]
+ self.t_input = self.token2id["<in>"]
+ self.t_output = self.token2id["<out>"]
+ self.t_prog = self.token2id["<prg>"]
self.t_end = self.token2id["<end>"]
self.train_input = self.tensorize(train_sequences)
f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
)
+ if save_attention_image is not None:
+ input = self.test_input[:10]
+ result = input.clone()
+ s = (result == self.t_prog).long()
+ ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
+ result = (1 - ar_mask) * result + ar_mask * self.t_nul
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ model.record_attention(True)
+ model(BracketedSequence(result))
+ model.train(t)
+ attention = model.retrieve_attention()
+ model.record_attention(False)
+
+ n_sample = 0
+ tokens_output = [self.id2token[i.item()] for i in result[n_sample]]
+ tokens_input = ["n/a"] + tokens_output[:-1]
+ for n_head in range(attention[0].size(1)):
+ filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+ save_attention_image(
+ filename,
+ tokens_input,
+ tokens_output,
+ attention,
+ n_sample=n_sample,
+ n_head=n_head,
+ token_gap=12,
+ layer_gap=40,
+ # k_top=2,
+ )
+ logger(f"wrote {filename}")
+
######################################################################