X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=0143ab20058de577a590f16ccfda2093adf174e4;hb=687d5b2d9f465577665991b84faec7c789685271;hp=b2f7d7dc5f750610333d03b7da6c183d83ff7a7a;hpb=16cb07f99cf770fb4e97824f874a68cbddd4c1cf;p=picoclvr.git diff --git a/tasks.py b/tasks.py index b2f7d7d..0143ab2 100755 --- a/tasks.py +++ b/tasks.py @@ -181,6 +181,39 @@ class SandBox(Task): f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + if save_attention_image is not None: + for k in range(10): + ns = torch.randint(self.test_input.size(0), (1,)).item() + input = self.test_input[ns : ns + 1].clone() + + with torch.autograd.no_grad(): + t = model.training + model.eval() + model.record_attention(True) + model(BracketedSequence(input)) + model.train(t) + ram = model.retrieve_attention() + model.record_attention(False) + + tokens_output = [c for c in self.problem.seq2str(input[0])] + tokens_input = ["n/a"] + tokens_output[:-1] + for n_head in range(ram[0].size(1)): + filename = os.path.join( + result_dir, f"sandbox_attention_{k}_h{n_head}.pdf" + ) + attention_matrices = [m[0, n_head] for m in ram] + save_attention_image( + filename, + tokens_input, + tokens_output, + attention_matrices, + k_top=10, + # min_total_attention=0.9, + token_gap=12, + layer_gap=50, + ) + logger(f"wrote {filename}") + ######################################################################