X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=agtree2dot.git;a=blobdiff_plain;f=mlp.py;h=3c5f026bfa2c8a760b10af1d0e0fb857a2db1dff;hp=8497848fa7c32413a05dbbf47e159d06764113fd;hb=d9e6125c82f4e3775b8c868751b1657a6a147f55;hpb=55332826c1d0ec125fc1d2db6644c98b1640d4a2 diff --git a/mlp.py b/mlp.py index 8497848..3c5f026 100755 --- a/mlp.py +++ b/mlp.py @@ -17,6 +17,8 @@ # Contact for comments & bug reports # ######################################################################### +import subprocess + from torch import nn from torch.nn import functional as fn from torch import Tensor @@ -45,8 +47,23 @@ criterion = nn.MSELoss() loss = criterion(output, target) agtree2dot.save_dot(loss, - { input: 'input', target: 'target', loss: 'loss' }, + { + input: 'input', + target: 'target', + loss: 'loss', + mlp.fc1.weight: 'weight1', + mlp.fc1.bias: 'bias1', + mlp.fc2.weight: 'weight2', + mlp.fc2.bias: 'bias2', + }, open('./mlp.dot', 'w')) -print('Generated mlp.dot. You can convert it to pdf with') -print('> dot mlp.dot -Lg -T pdf -o mlp.pdf') +print('Generated mlp.dot') + +try: + subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf" ]) +except subprocess.CalledProcessError: + print('Calling the dot command failed. Is Graphviz installed?') + sys.exit(1) + +print('Generated mlp.pdf')