#!/usr/bin/env luajit

-- Smoke script for the dagnn `nn.DAG` container: builds a small graph with a
-- fan-out at `b` and two outputs (`e` and `f`), runs one forward pass and
-- prints the resulting tensor(s).

require 'torch'
require 'nn'
require 'dagnn'

--- Print a tensor, or recursively print every element of a (possibly nested)
-- table of tensors. Each table element is preceded by a
-- "-- ELEMENT [key] --" header line.
-- @param t a tensor (or any value accepted by `tostring`), or a table of them
local function printTensorTable(t)
   if torch.type(t) == 'table' then
      -- NOTE: pairs() traversal order is unspecified by the Lua language.
      for key, element in pairs(t) do
         -- tostring() so keys that are neither strings nor numbers do not
         -- make the concatenation raise an error.
         print('-- ELEMENT [' .. tostring(key) .. '] --')
         printTensorTable(element)
      end
   else
      print(tostring(t))
   end
end

-- Left disabled: `params` is not defined anywhere in this script.
-- torch.setnumthreads(params.nbThreads)

torch.setdefaulttensortype('torch.DoubleTensor')
-- Fixed seed so the random Linear weights and the uniform input are
-- reproducible from run to run.
torch.manualSeed(2)

local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable()
local f = nn.Linear(3, 2)

-- a -----> b ---> c ----> e ---
--           \            /
--            \--> d ----/
--                  \
--                   \---> f ---

-- Edges are added in no particular topological order.
local g = nn.DAG()
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
g:addEdge(d, f)
g:setInput({a})
g:setOutput({e, f})
g:print()

-- 3 x 10 tensor; presumably treated by nn.Linear as a mini-batch of three
-- 10-dimensional samples.
local input = torch.Tensor(3, 10):uniform()
-- The input is wrapped in a table to mirror setInput({a}); since two output
-- modules were registered, `output` is expected to be a table of two
-- tensors (one per output module) — printTensorTable handles either case.
local output = g:updateOutput({input})
printTensorTable(output)

----------------------------------------------------------------------

-- NOTE(review): backward pass left disabled; `output` would stand in for
-- gradOutput here.
-- gradInput = g:updateGradInput({ input }, output)