X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=dagnn.git;a=blobdiff_plain;f=dagnn.lua;h=b82398c0de429bd19875776fa465222a84504bbe;hp=1f45b2acae3bc546495e2b33bd0666f6f577c54e;hb=HEAD;hpb=e73c3494970d12154aff7587fcb43cf600f03e30 diff --git a/dagnn.lua b/dagnn.lua index 1f45b2a..b82398c 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -69,6 +69,7 @@ function DAG:putInOrder() local nc local nl = 0 repeat + assert(nl < #self.modules, 'Cycle detected in the graph.') nc = 0 for nnma, node in pairs(self.node) do for _, nnmb in pairs(node.succ) do @@ -78,12 +79,11 @@ function DAG:putInOrder() end end end - assert(nl < #self.modules, 'Cycle detected in the graph.') nl = nl + 1 until nc == 0 for _, nnm in pairs(self.modules) do - assert(distance[nnm], 'Some modules are not connected to inputs') + assert(distance[nnm], 'Some modules are not connected to inputs.') end self.sorted = {} @@ -148,6 +148,10 @@ function DAG:connect(...) end end +function DAG:setLabel(nnm, label) + self.node[nnm].label = label +end + function DAG:setInput(i) self.sorted = nil self.inputModules = i @@ -176,7 +180,11 @@ function DAG:print() self:putInOrder() for i, d in ipairs(self.sorted) do - print('#' .. i .. ' -> ' .. torch.type(d)) + local decoration = '' + if self.node[d].label then + decoration = ' [' .. self.node[d].label .. ']' + end + print('#' .. i .. ' -> ' .. torch.type(d) .. decoration) end end @@ -185,15 +193,33 @@ end function DAG:saveDot(filename) local file = (filename and io.open(filename, 'w')) or io.stdout + local function writeNestedCluster(prefix, list, indent) + local indent = indent or '' + if torch.type(list) == 'table' then + file:write(indent .. ' subgraph cluster_' .. prefix .. ' {\n'); + for k, x in pairs(list) do + writeNestedCluster(prefix .. '_' .. k, x, ' ' .. indent) + end + file:write(indent .. ' }\n'); + else + file:write(indent .. ' ' .. self.node[list].index .. ' [color=red]\n') + end + end + file:write('digraph {\n') file:write('\n') + writeNestedCluster('input', self.inputModules) + writeNestedCluster('output', self.outputModules) + + file:write('\n') + for nnmb, node in pairs(self.node) do file:write( ' ' .. node.index - .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]' + .. ' [shape=box,label=\"' .. (self.node[nnmb].label or torch.type(nnmb)) .. '\"]' .. '\n' ) @@ -262,7 +288,7 @@ function DAG:updateOutput(input) end function DAG:updateGradInput(input, gradOutput) - assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput') + assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput.') self:nestedApply( function(nnm, go) @@ -297,7 +323,7 @@ function DAG:updateGradInput(input, gradOutput) table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput) elseif #pred > 1 then assert(torch.type(nnm.gradInput) == 'table', - 'Should have a table gradInput since it has multiple predecessors') + 'Should have a table gradInput since it has multiple predecessors.') for n = 1, #pred do table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n]) end @@ -313,7 +339,7 @@ function DAG:updateGradInput(input, gradOutput) end function DAG:accGradParameters(input, gradOutput, scale) - assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters') + assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters.') self:nestedApply( function(nnm, go) self.node[nnm].gradOutput = go end,