a45d6365d61a6183b7b1b49758cddb89272d7708
[dagnn.git] / test-dagnn.lua
1 #!/usr/bin/env luajit
2
3 require 'torch'
4 require 'nn'
5
6 require 'dagnn'
7
8 -- torch.setnumthreads(params.nbThreads)
9 torch.setdefaulttensortype('torch.DoubleTensor')
10 torch.manualSeed(2)
11
12 a = nn.Linear(10, 10)
13 b = nn.ReLU()
14 c = nn.Linear(10, 3)
15 d = nn.Linear(10, 3)
16 e = nn.CMulTable()
17 f = nn.Linear(3, 2)
18
19 --[[
20
21    a -----> b ---> c ----> e ---
22              \           /
23               \--> d ---/
24                     \
25                      \---> f ---
26 ]]--
27
28 g = nn.DAG:new()
29
30 g:setInput(a)
31 g:setOutput({ e })
32
33 g:addEdge(c, e)
34 g:addEdge(a, b)
35 g:addEdge(d, e)
36 g:addEdge(b, c)
37 g:addEdge(b, d)
38 -- g:addEdge(d, f)
39
40 -- g = torch.load('dag.t7')
41
42 g:print()
43
44 input = torch.Tensor(3, 10):uniform()
45
46 output = g:updateOutput(input)
47
48 if torch.type(output) == 'table' then
49    for i, t in pairs(output) do
50       print(tostring(i) .. ' -> ' .. tostring(t))
51    end
52 else
53    print(tostring(output))
54 end
55
56 torch.save('dag.t7', g)