X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=graph.py;fp=graph.py;h=c286388d6298eba52f7ad88ed9bffa0d5ab2af9a;hb=8492656cf0cc5de4f7e2c4aa8ccb717193293b40;hp=a819283cc28450c28b54a6e5b6b215c408c65243;hpb=3b9ba21fd3d06a20703216cc0a77fe9dc78b079f;p=picoclvr.git diff --git a/graph.py b/graph.py index a819283..c286388 100755 --- a/graph.py +++ b/graph.py @@ -161,7 +161,7 @@ if __name__ == "__main__": nb_heads=2, nb_blocks=5, dropout=0.1, - causal=True, + #causal=True, ) model.eval() @@ -171,6 +171,8 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] + + # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] # for a in attention_matrices: a=a/a.sum(-1,keepdim=True)