Update.
authorFrancois Fleuret <francois@fleuret.org>
Sun, 6 Jun 2021 12:30:34 +0000 (14:30 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Sun, 6 Jun 2021 12:30:34 +0000 (14:30 +0200)
conv_chain.py

index 04dfdfa..fa5d752 100755 (executable)
@@ -24,19 +24,21 @@ def conv_chain(input_size, output_size, depth, cond):
 
 ######################################################################
 
-# Example
+if __name__ == "__main__":
 
-c = conv_chain(
-    input_size = 64, output_size = 8,
-    depth = 5,
-    cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2
-)
+    # Example
 
-x = torch.rand(1, 1, 64)
+    c = conv_chain(
+        input_size = 64, output_size = 8,
+        depth = 5,
+        cond = lambda k, s: k <= 4 and s <= 2 and s <= k//2
+    )
 
-for m in c:
-    m = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ])
-    print(m)
-    print(x.size(), m(x).size())
+    x = torch.rand(1, 1, 64)
+
+    for m in c:
+        model = nn.Sequential(*[ nn.Conv1d(1, 1, l[0], l[1]) for l in m ])
+        print(model)
+        print(x.size(), model(x).size())
 
 ######################################################################