X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=f9d6ff9b9248185b73ee5add602826d6f8e4a260;hb=cd1aa6aaa75f5d5b281be0cfbacd51991b3b1ca3;hp=1f45b2acae3bc546495e2b33bd0666f6f577c54e;hpb=e73c3494970d12154aff7587fcb43cf600f03e30;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 1f45b2a..f9d6ff9 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -148,6 +148,10 @@ function DAG:connect(...) end end +function DAG:setLabel(nnm, label) + self.node[nnm].label = label +end + function DAG:setInput(i) self.sorted = nil self.inputModules = i @@ -176,7 +180,11 @@ function DAG:print() self:putInOrder() for i, d in ipairs(self.sorted) do - print('#' .. i .. ' -> ' .. torch.type(d)) + local decoration = '' + if self.node[d].label then + decoration = ' [' .. self.node[d].label .. ']' + end + print('#' .. i .. ' -> ' .. torch.type(d) .. decoration) end end @@ -185,15 +193,33 @@ end function DAG:saveDot(filename) local file = (filename and io.open(filename, 'w')) or io.stdout + local function writeNestedCluster(prefix, list, indent) + local indent = indent or '' + if torch.type(list) == 'table' then + file:write(indent .. ' subgraph cluster_' .. prefix .. ' {\n'); + for k, x in pairs(list) do + writeNestedCluster(prefix .. '_' .. k, x, ' ' .. indent) + end + file:write(indent .. ' }\n'); + else + file:write(indent .. ' ' .. self.node[list].index .. ' [color=red]\n') + end + end + file:write('digraph {\n') file:write('\n') + writeNestedCluster('input', self.inputModules) + writeNestedCluster('output', self.outputModules) + + file:write('\n') + for nnmb, node in pairs(self.node) do file:write( ' ' .. node.index - .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]' + .. ' [shape=box,label=\"' .. (self.node[nnmb].label or torch.type(nnmb)) .. '\"]' .. '\n' )