- for i, d in pairs(distance) do
- table.insert(self.sorted, { d, i })
+ for m, d in pairs(distance) do
+ table.insert(self.sorted, { distance = d, nnm = m })
+ end
+
+ table.sort(self.sorted, function(a, b) return a.distance < b.distance end)
+
+ for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end
+end
+
+function DAG:updateGradOutput(node)
+ local gradInputSucc = node.gradInputSucc
+ if #gradInputSucc == 1 then
+ node.gradOutput = gradInputSucc[1]
+ elseif #gradInputSucc > 1 then
+ if node.gradOutput then
+ node.gradOutput:resize(gradInputSucc[1]):copy(gradInputSucc[1])
+ else
+ node.gradOutput = gradInputSucc[1]:clone()
+ end
+ for k = 2, #gradInputSucc do
+ node.gradOutput:add(gradInputSucc[k])
+ end