In DAG:saveDot, groups input and output nodes in nested clusters.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 15 Jan 2017 11:04:13 +0000 (12:04 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 15 Jan 2017 11:04:13 +0000 (12:04 +0100)
dagnn.lua

index 1f45b2a..ca26926 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -185,10 +185,28 @@ end
 function DAG:saveDot(filename)
    local file = (filename and io.open(filename, 'w')) or io.stdout
 
+   local function writeNestedCluster(prefix, list, indent)
+      local indent = indent or ''
+      if torch.type(list) == 'table' then
+         file:write(indent .. '  subgraph cluster_' .. prefix .. ' {\n');
+         for k, x in pairs(list) do
+            writeNestedCluster(prefix .. '_' .. k, x, '  ' .. indent)
+         end
+         file:write(indent .. '  }\n');
+      else
+         file:write(indent .. '  ' .. self.node[list].index .. ' [color=red]\n')
+      end
+   end
+
    file:write('digraph {\n')
 
    file:write('\n')
 
+   writeNestedCluster('input', self.inputModules)
+   writeNestedCluster('output', self.outputModules)
+
+   file:write('\n')
+
    for nnmb, node in pairs(self.node) do
       file:write(
          '  '