X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=08f1170b9dfc4838340e972d8731ace834967246;hb=95717a8bf88159051f9c4b8862b0b643187826e9;hp=a819283cc28450c28b54a6e5b6b215c408c65243;hpb=3b9ba21fd3d06a20703216cc0a77fe9dc78b079f;p=picoclvr.git diff --git a/graph.py b/graph.py index a819283..08f1170 100755 --- a/graph.py +++ b/graph.py @@ -14,16 +14,10 @@ import cairo def save_attention_image( - filename, + filename, # image to save tokens_input, tokens_output, - # An iterable set of BxHxTxT attention matrices - attention_matrices, - pixel_scale=8, - token_gap=15, - layer_gap=25, - y_eps=0.5, - padding=10, + attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk # do not draw links with a lesser attention min_link_attention=0, # draw only the strongest links necessary to reache @@ -32,6 +26,11 @@ def save_attention_image( # draw only the top k links k_top=None, curved=True, + pixel_scale=8, + token_gap=15, + layer_gap=25, + y_eps=0.5, + padding=10, ): if k_top is not None: am = []