X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=dagnn.lua;h=9202932da2950905fc49734c8762226e0e4d6f28;hb=da6186a657b7563841416c42336e52937b76d67f;hp=7fc1018f8dd7f3f88021c914608c5af1f5a710a5;hpb=063f198047f0202fa921aa09b772369b14ae8be2;p=dagnn.git diff --git a/dagnn.lua b/dagnn.lua index 7fc1018..9202932 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -40,12 +40,19 @@ function DAG:createNode(nnm) end end -function DAG:addEdge(nnma, nnmb) +-- The main use should be to add an edge between two modules, but it +-- can also add a full sequence of modules +function DAG:addEdge(...) self.sorted = nil - self:createNode(nnma) - self:createNode(nnmb) - table.insert(self.node[nnmb].pred, nnma) - table.insert(self.node[nnma].succ, nnmb) + local prev + for _, nnm in pairs({...}) do + self:createNode(nnm) + if prev then + table.insert(self.node[nnm].pred, prev) + table.insert(self.node[prev].succ, nnm) + end + prev = nnm + end end -- Apply f on t recursively; use the corresponding element from args