X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=0143ab20058de577a590f16ccfda2093adf174e4;hb=687d5b2d9f465577665991b84faec7c789685271;hp=421aee49f2f0005f7650ea9836fde801fc5598e8;hpb=db7cefe4fefb381e56f1292d5bbe4a18c76afb47;p=picoclvr.git

diff --git a/tasks.py b/tasks.py
index 421aee4..0143ab2 100755
--- a/tasks.py
+++ b/tasks.py
@@ -76,6 +76,7 @@ class Task:
 
 import problems
 
+
 class SandBox(Task):
     def __init__(
         self,
@@ -180,6 +181,39 @@ class SandBox(Task):
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
         )
 
+        if save_attention_image is not None:
+            for k in range(10):
+                ns = torch.randint(self.test_input.size(0), (1,)).item()
+                input = self.test_input[ns : ns + 1].clone()
+
+                with torch.autograd.no_grad():
+                    t = model.training
+                    model.eval()
+                    model.record_attention(True)
+                    model(BracketedSequence(input))
+                    model.train(t)
+                    ram = model.retrieve_attention()
+                    model.record_attention(False)
+
+                tokens_output = [c for c in self.problem.seq2str(input[0])]
+                tokens_input = ["n/a"] + tokens_output[:-1]
+                for n_head in range(ram[0].size(1)):
+                    filename = os.path.join(
+                        result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
+                    )
+                    attention_matrices = [m[0, n_head] for m in ram]
+                    save_attention_image(
+                        filename,
+                        tokens_input,
+                        tokens_output,
+                        attention_matrices,
+                        k_top=10,
+                        # min_total_attention=0.9,
+                        token_gap=12,
+                        layer_gap=50,
+                    )
+                    logger(f"wrote {filename}")
+
 
 ######################################################################
 
@@ -1134,8 +1168,8 @@ class RPL(Task):
         )
 
         if save_attention_image is not None:
-            ns=torch.randint(self.test_input.size(0),(1,)).item()
-            input = self.test_input[ns:ns+1].clone()
+            ns = torch.randint(self.test_input.size(0), (1,)).item()
+            input = self.test_input[ns : ns + 1].clone()
 
             last = (input != self.t_nul).max(0).values.nonzero().max() + 3
             input = input[:, :last].to(self.device)
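Note on the pattern introduced above: the new SandBox block reuses the record/retrieve attention protocol already present in RPL — recording is switched on around a single forward pass under no_grad, the per-layer attention maps are pulled out, and one PDF is written per head. Below is a minimal sketch of that protocol, assuming a model exposing record_attention()/retrieve_attention() and the BracketedSequence wrapper as they appear in the diff; the dump_attention helper name is hypothetical, and the save_attention_image keyword arguments are taken from the call site above rather than from any documented API.

    import torch
    from mygpt import BracketedSequence  # assumption: the wrapper is defined alongside the model

    def dump_attention(model, input, tokens_input, tokens_output, save_attention_image, prefix="attention"):
        # One forward pass under no_grad with attention recording switched on;
        # eval() disables dropout so the recorded maps are deterministic.
        with torch.no_grad():
            was_training = model.training
            model.eval()
            model.record_attention(True)
            model(BracketedSequence(input))  # fills the per-layer attention buffers
            model.train(was_training)
            ram = model.retrieve_attention()  # one (N, n_heads, T, T) tensor per layer
            model.record_attention(False)

        # Render head n_head of sample 0, stacking the layers into one image.
        for n_head in range(ram[0].size(1)):
            save_attention_image(
                f"{prefix}_h{n_head}.pdf",
                tokens_input,
                tokens_output,
                [m[0, n_head] for m in ram],
                k_top=10,  # keep only the strongest attention links per token
                token_gap=12,
                layer_gap=50,
            )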