OCD cosmetics.
[dagnn.git] / dagnn.lua
index ca26926..b82398c 100755 (executable)
--- a/dagnn.lua
+++ b/dagnn.lua
@@ -69,6 +69,7 @@ function DAG:putInOrder()
    local nc
    local nl = 0
    repeat
+      assert(nl < #self.modules, 'Cycle detected in the graph.')
       nc = 0
       for nnma, node in pairs(self.node) do
          for _, nnmb in pairs(node.succ) do
@@ -78,12 +79,11 @@ function DAG:putInOrder()
             end
          end
       end
-      assert(nl < #self.modules, 'Cycle detected in the graph.')
       nl = nl + 1
    until nc == 0
 
    for _, nnm in pairs(self.modules) do
-      assert(distance[nnm], 'Some modules are not connected to inputs')
+      assert(distance[nnm], 'Some modules are not connected to inputs.')
    end
 
    self.sorted = {}
@@ -148,6 +148,10 @@ function DAG:connect(...)
    end
 end
 
+function DAG:setLabel(nnm, label)
+   self.node[nnm].label = label
+end
+
 function DAG:setInput(i)
    self.sorted = nil
    self.inputModules = i
@@ -176,7 +180,11 @@ function DAG:print()
    self:putInOrder()
 
    for i, d in ipairs(self.sorted) do
-      print('#' .. i .. ' -> ' .. torch.type(d))
+      local decoration = ''
+      if self.node[d].label then
+         decoration = ' [' .. self.node[d].label .. ']'
+      end
+      print('#' .. i .. ' -> ' .. torch.type(d) .. decoration)
    end
 end
 
@@ -211,7 +219,7 @@ function DAG:saveDot(filename)
       file:write(
          '  '
             .. node.index
-            .. ' [shape=box,label=\"' .. torch.type(nnmb) .. '\"]'
+            .. ' [shape=box,label=\"' .. (self.node[nnmb].label or torch.type(nnmb)) .. '\"]'
             .. '\n'
       )
 
@@ -280,7 +288,7 @@ function DAG:updateOutput(input)
 end
 
 function DAG:updateGradInput(input, gradOutput)
-   assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput')
+   assert(self.sorted, 'There has been a structure change before a DAG:updateGradInput.')
 
    self:nestedApply(
       function(nnm, go)
@@ -315,7 +323,7 @@ function DAG:updateGradInput(input, gradOutput)
          table.insert(self.node[pred[1]].gradInputSucc, nnm.gradInput)
       elseif #pred > 1 then
          assert(torch.type(nnm.gradInput) == 'table',
-                'Should have a table gradInput since it has multiple predecessors')
+                'Should have a table gradInput since it has multiple predecessors.')
          for n = 1, #pred do
             table.insert(self.node[pred[n]].gradInputSucc, nnm.gradInput[n])
          end
@@ -331,7 +339,7 @@ function DAG:updateGradInput(input, gradOutput)
 end
 
 function DAG:accGradParameters(input, gradOutput, scale)
-   assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters')
+   assert(self.sorted, 'There has been a structure change before a DAG:accGradParameters.')
 
    self:nestedApply(
       function(nnm, go) self.node[nnm].gradOutput = go end,