Added DAG:dot() to generate a dot file for visualization.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 21:35:26 +0000 (22:35 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 12 Jan 2017 21:35:26 +0000 (22:35 +0100)
dagnn.lua
test-dagnn.lua

index 9202932..c6d54ad 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -261,3 +261,35 @@ function DAG:accGradParameters(input, gradOutput, scale)
 end
 
 ----------------------------------------------------------------------
+
+function DAG:dot(filename)
+   local file = (filename and io.open(filename, 'w')) or io.stdout
+
+   file:write('digraph {\n')
+
+   file:write('\n')
+
+   for nnma, node in pairs(self.node) do
+      file:write(
+         '  '
+            .. node.index
+            .. ' [shape=box,label=\"' .. torch.type(nnma) .. '\"]'
+            .. '\n'
+      )
+
+      for _, nnmb in pairs(node.succ) do
+         file:write(
+            '  '
+               .. node.index
+               .. ' -> '
+               .. self.node[nnmb].index
+               .. '\n'
+         )
+      end
+
+      file:write('\n')
+   end
+
+   file:write('}\n')
+
+end
index 3dea310..53302fd 100755 (executable)
@@ -108,3 +108,5 @@ local output = model:updateOutput(input):clone()
 output:uniform()
 
 print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output))
+
+model:dot('/tmp/graph.dot')