X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;h=2c7caf801acf48489228dfaa1661de459f4240ea;hb=6c8bed86221baae24a7c2aaaa41c009444efb5c9;hp=bd801875bea7531fb0a480d609ff408f887956d6;hpb=7aef882f33f5ca180a9a9c11c5aab8ce0f099685;p=picoclvr.git diff --git a/graph.py b/graph.py index bd80187..2c7caf8 100755 --- a/graph.py +++ b/graph.py @@ -59,7 +59,7 @@ def save_attention_image( ctx.set_line_width(0.25) for d in range(len(attention_matrices)): - at = attention_matrices[d] + at = attention_matrices[d].to("cpu") ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) at = at.flatten()