X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=6db9ed7bc53c5a6adc3a3c9047139feac4ed087e;hb=687d5b2d9f465577665991b84faec7c789685271;hp=2c7caf801acf48489228dfaa1661de459f4240ea;hpb=b718ef527d4bfb014a9ad564bb5199c7d0780aa9;p=picoclvr.git diff --git a/graph.py b/graph.py index 2c7caf8..6db9ed7 100755 --- a/graph.py +++ b/graph.py @@ -110,7 +110,7 @@ def save_attention_image( x_advance, y_advance, ) = ctx.text_extents(s) - ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing) + ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5) ctx.show_text(s) for k, t in enumerate(tokens_output): @@ -146,7 +146,7 @@ def save_attention_image( if __name__ == "__main__": import mygpt - tokens_output = ["", 2, 3, 4, ""] + tokens_output = ["", "-", 3, 4, ""] tokens_input = [""] + tokens_output[:-1] vocabulary_size = 3