Made the example more complicated to check that DAGs can be combined with other modules.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:21:09 +0000 (16:21 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 13 Jan 2017 15:21:09 +0000 (16:21 +0100)
test-dagnn.lua

index 3801956..f7de819 100755 (executable)
@@ -75,37 +75,41 @@ function printTensorTable(t)
    end
 end
 
---               +-- Linear(10, 10) --> ReLU --> d --+
---              /                              /      \
---             /                              /        \
---  --> a --> b -----------> c --------------+          e -->
---                            \                        /
---                             \                      /
---                              +----- Mul(-1) ------+
+--               +-- Linear(10, 10) --> ReLU --> d -->
+--              /                               /
+--             /                               /
+--  --> a --> b -----------> c ---------------+
+--                            \
+--                             \
+--                              +--------------- e -->
 
-model = nn.DAG()
+dag = nn.DAG()
 
 a = nn.Linear(50, 10)
 b = nn.ReLU()
 c = nn.Linear(10, 15)
 d = nn.CMulTable()
-e = nn.CAddTable()
+e = nn.Mul(-1)
 
-model:connect(a, b, c)
-model:connect(b, nn.Linear(10, 15), nn.ReLU(), d)
-model:connect(d, e)
-model:connect(c, d)
-model:connect(c, nn.Mul(-1), e)
+dag:connect(a, b, c)
+dag:connect(b, nn.Linear(10, 15), nn.ReLU(), d)
+dag:connect(c, d)
+dag:connect(c, e)
 
-model:setInput(a)
-model:setOutput(e)
+dag:setInput(a)
+dag:setOutput({ d, e })
+
+-- We check it works when we put it into a nn.Sequential
+model = nn.Sequential()
+   :add(nn.Linear(50, 50))
+   :add(dag)
+   :add(nn.CAddTable())
 
 local input = torch.Tensor(30, 50):uniform()
 local output = model:updateOutput(input):clone()
-
 output:uniform()
 
-print('Error = ' .. checkGrad(model, nn.MSECriterion(), input, output))
+print('Gradient estimate error ' .. checkGrad(model, nn.MSECriterion(), input, output))
 
 print('Writing /tmp/graph.dot')
-model:saveDot('/tmp/graph.dot')
+dag:saveDot('/tmp/graph.dot')