From: Francois Fleuret Date: Thu, 12 Jan 2017 21:35:26 +0000 (+0100) Subject: Added DAG:dot() to generate a dot file for visualization. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0a630b54355382dfa68c0f3d51729bad0b4c58e6;p=dagnn.git Added DAG:dot() to generate a dot file for visualization. --- diff --git a/dagnn.lua b/dagnn.lua index 9202932..c6d54ad 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -261,3 +261,35 @@ function DAG:accGradParameters(input, gradOutput, scale) end ---------------------------------------------------------------------- + +function DAG:dot(filename) + local file = (filename and io.open(filename, 'w')) or io.stdout + + file:write('digraph {\n') + + file:write('\n') + + for nnma, node in pairs(self.node) do + file:write( + ' ' + .. node.index + .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]' + .. '\n' + ) + + for _, nnmb in pairs(node.succ) do + file:write( + ' ' + .. node.index + .. ' -> ' + .. self.node[nnmb].index + .. '\n' + ) + end + + file:write('\n') + end + + file:write('}\n') + +end diff --git a/test-dagnn.lua b/test-dagnn.lua index 3dea310..53302fd 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -108,3 +108,5 @@ local output = model:updateOutput(input):clone() output:uniform() print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output)) + +model:dot('/tmp/graph.dot')