X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=2c7caf801acf48489228dfaa1661de459f4240ea;hb=291c38d093894d46fba6eb45f82e5b65a2a1cb8b;hp=c286388d6298eba52f7ad88ed9bffa0d5ab2af9a;hpb=8492656cf0cc5de4f7e2c4aa8ccb717193293b40;p=picoclvr.git diff --git a/graph.py b/graph.py index c286388..2c7caf8 100755 --- a/graph.py +++ b/graph.py @@ -14,24 +14,23 @@ import cairo def save_attention_image( - filename, + filename, # image to save tokens_input, tokens_output, - # An iterable set of BxHxTxT attention matrices - attention_matrices, - pixel_scale=8, - token_gap=15, - layer_gap=25, - y_eps=0.5, - padding=10, + attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk # do not draw links with a lesser attention min_link_attention=0, - # draw only the strongest links necessary to reache - # min_total_attention + # draw only the strongest links necessary so that their summed + # attention is above min_total_attention min_total_attention=None, # draw only the top k links k_top=None, curved=True, + pixel_scale=8, + token_gap=15, + layer_gap=25, + y_eps=0.5, + padding=10, ): if k_top is not None: am = [] @@ -60,7 +59,7 @@ def save_attention_image( ctx.set_line_width(0.25) for d in range(len(attention_matrices)): - at = attention_matrices[d] + at = attention_matrices[d].to("cpu") ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) at = at.flatten() @@ -161,7 +160,7 @@ if __name__ == "__main__": nb_heads=2, nb_blocks=5, dropout=0.1, - #causal=True, + causal=True, ) model.eval() @@ -171,8 +170,6 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] - - # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] # for a in attention_matrices: a=a/a.sum(-1,keepdim=True)