#!/usr/bin/env luajit

-- Smoke test for the dagnn package: builds a small DAG of nn modules,
-- runs one forward pass on random input, prints the result, and
-- serializes the graph to 'dag.t7'.
--
-- NOTE: the shebang must stay alone on line 1. Lua skips the first line
-- of a file when it starts with '#', so any code on that line never runs.

require 'torch'
require 'nn'
require 'dagnn'

-- torch.setnumthreads(params.nbThreads)
torch.setdefaulttensortype('torch.DoubleTensor')
-- Fixed seed so the Linear weights and the random input are reproducible.
torch.manualSeed(2)

local a = nn.Linear(10, 10)
local b = nn.ReLU()
local c = nn.Linear(10, 3)
local d = nn.Linear(10, 3)
local e = nn.CMulTable()
-- Not connected: its only edge (d -> f) is commented out below.
local f = nn.Linear(3, 2)

--[[
  a -----> b ---> c ----> e ---
            \            /
             \--> d ----/
                   \
                    \---> f ---
]]--

local g = nn.DAG:new()
g:setInput(a)
g:setOutput({ e })
-- Edges are added deliberately out of topological order.
-- NOTE(review): presumably this exercises dagnn's ordering logic; confirm.
g:addEdge(c, e)
g:addEdge(a, b)
g:addEdge(d, e)
g:addEdge(b, c)
g:addEdge(b, d)
-- g:addEdge(d, f)
-- g = torch.load('dag.t7')

g:print()

-- A batch of 3 samples with 10 features each, matching a's input size.
local input = torch.Tensor(3, 10):uniform()
local output = g:updateOutput(input)

-- The graph may return either a single tensor or a table of tensors.
if torch.type(output) == 'table' then
  for i, t in pairs(output) do
    print(tostring(i) .. ' -> ' .. tostring(t))
  end
else
  print(tostring(output))
end

torch.save('dag.t7', g)