X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=07e376a3af86d24874523775e26f16132a489512;hb=0b8185b90014369f0d39892e128ad04a7d9ae872;hp=bd801875bea7531fb0a480d609ff408f887956d6;hpb=7aef882f33f5ca180a9a9c11c5aab8ce0f099685;p=picoclvr.git diff --git a/graph.py b/graph.py index bd80187..07e376a 100755 --- a/graph.py +++ b/graph.py @@ -14,10 +14,12 @@ import cairo def save_attention_image( - filename, # image to save + # image to save + filename, tokens_input, tokens_output, - attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk + # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1 + attention_matrices, # do not draw links with a lesser attention min_link_attention=0, # draw only the strongest links necessary so that their summed @@ -25,6 +27,7 @@ def save_attention_image( min_total_attention=None, # draw only the top k links k_top=None, + # the purely graphical settings curved=True, pixel_scale=8, token_gap=15, @@ -59,7 +62,7 @@ def save_attention_image( ctx.set_line_width(0.25) for d in range(len(attention_matrices)): - at = attention_matrices[d] + at = attention_matrices[d].to("cpu") ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) at = at.flatten() @@ -110,7 +113,7 @@ def save_attention_image( x_advance, y_advance, ) = ctx.text_extents(s) - ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing) + ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5) ctx.show_text(s) for k, t in enumerate(tokens_output): @@ -146,7 +149,7 @@ def save_attention_image( if __name__ == "__main__": import mygpt - tokens_output = ["", 2, 3, 4, ""] + tokens_output = ["", "-", 3, 4, ""] tokens_input = [""] + tokens_output[:-1] vocabulary_size = 3 @@ -170,8 +173,7 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] - # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] - # for a in attention_matrices: a=a/a.sum(-1,keepdim=True) + # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]] save_attention_image( "attention.pdf",