Introduction

This package provides a function that generates a Graphviz dot file from a PyTorch autograd graph.

Usage

Functions

agtree2dot.save_dot(variable, variable_labels, result_file)

Writes to result_file a Graphviz dot description of the autograd graph that produced variable. result_file is a file object opened for writing. The dictionary variable_labels maps some of the variables in the graph to strings, which are used as their labels in the resulting graph.
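The traversal and output format used by save_dot are not documented here. As a rough illustration of the idea only, the sketch below walks a graph backwards from its final node and writes a Graphviz digraph, looking up display names in a labels dictionary, much as variable_labels maps variables to strings. Node, its inputs list and save_dot_sketch are hypothetical stand-ins, not part of agtree2dot or PyTorch; a real implementation would follow each variable's autograd history instead.

```python
import io


class Node:
    """Hypothetical stand-in for a graph node: a name plus its input nodes."""

    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)


def save_dot_sketch(root, labels, out):
    """Write a Graphviz digraph of every node reachable from root.

    labels maps some nodes to display strings; other nodes fall back to
    their own name. Nodes are tracked by id(), so a node shared by several
    consumers is emitted once.
    """
    ids = {}    # id(node) -> dot identifier such as "n0"
    order = []  # nodes in the order they were first visited
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in ids:
            continue
        ids[id(node)] = "n%d" % len(ids)
        order.append(node)
        stack.extend(node.inputs)

    out.write("digraph {\n")
    for node in order:
        label = labels.get(node, node.name).replace('"', '\\"')
        out.write('  %s [label="%s"];\n' % (ids[id(node)], label))
    for node in order:
        # Edges point from each input to the node that consumes it.
        for inp in node.inputs:
            out.write("  %s -> %s;\n" % (ids[id(inp)], ids[id(node)]))
    out.write("}\n")


# Usage: two labelled leaves feeding one operation.
x = Node("x")
w = Node("w")
mul = Node("MulBackward", [x, w])
buf = io.StringIO()
save_dot_sketch(mul, {x: "input", w: "weight"}, buf)
print(buf.getvalue())
```

Keying the labels by the node objects themselves, rather than by names, lets two distinct nodes with the same name carry different labels.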

Example

A typical use is provided in mlp.py:

import subprocess
import sys

import torch
from torch import nn
from torch import Tensor
from torch.autograd import Variable
from torch.nn import Module

import agtree2dot

class MLP(Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.tanh(x)  # torch.nn.functional.tanh is deprecated
        x = self.fc2(x)
        return x

mlp = MLP(10, 20, 1)
input = Variable(Tensor(100, 10).normal_())
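# The target needs shape (100, 1) to match the output; a (100,) target
# would be broadcast against it by MSELoss and give a wrong loss.
target = Variable(Tensor(100, 1).normal_())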
output = mlp(input)
criterion = nn.MSELoss()
loss = criterion(output, target)

# The with block closes (and flushes) mlp.dot before dot reads it below.
with open('./mlp.dot', 'w') as f:
    agtree2dot.save_dot(loss,
                        {
                            input: 'input',
                            target: 'target',
                            loss: 'loss',
                            mlp.fc1.weight: 'weight1',
                            mlp.fc1.bias: 'bias1',
                            mlp.fc2.weight: 'weight2',
                            mlp.fc2.bias: 'bias2',
                        },
                        f)

print('Generated mlp.dot')

try:
    subprocess.check_call(["dot", "mlp.dot", "-Lg", "-T", "pdf", "-o", "mlp.pdf"])
except (OSError, subprocess.CalledProcessError):
    # OSError is raised when the dot executable cannot be found at all
    print('Calling the dot command failed. Is Graphviz installed?')
    sys.exit(1)

print('Generated mlp.pdf')

This writes mlp.dot, then tries to convert it to mlp.pdf with the Graphviz dot command.
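If you would rather check for Graphviz before calling it than rely on the exception, the conversion step can be wrapped as below. This is a minimal sketch using only the standard library; render_dot is a hypothetical helper, not part of agtree2dot, and it passes only the -T (output format) and -o (output file) options from the example.

```python
import shutil
import subprocess


def render_dot(dot_path, out_path, fmt="pdf", dot_exe="dot"):
    """Convert a dot file with Graphviz; return True on success.

    The executable is looked up on PATH first, so a missing Graphviz
    install is reported without launching a subprocess.
    """
    exe = shutil.which(dot_exe)
    if exe is None:
        print("%s not found on PATH. Is Graphviz installed?" % dot_exe)
        return False
    try:
        # -T selects the output format, -o names the output file.
        subprocess.check_call([exe, dot_path, "-T", fmt, "-o", out_path])
    except subprocess.CalledProcessError:
        print("dot exited with an error while processing %s" % dot_path)
        return False
    return True
```

In mlp.py this would replace the try block with something like `if not render_dot('mlp.dot', 'mlp.pdf'): sys.exit(1)`.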