Update.
[pytorch.git] / sizer.py
index cc0a19e..5887e4a 100755 (executable)
--- a/sizer.py
+++ b/sizer.py
@@ -10,10 +10,12 @@ import time
 import torch
 from torch import nn
 
-t = 0
+######################################################################
 
 if len(sys.argv) < 2:
-    print(sys.argv[0] + ''' <file to monitor>
+    print(
+        sys.argv[0]
+        + """ <file to monitor>
 
 For example:
 
@@ -24,23 +26,28 @@ nn.Conv2d(32, 32, 3, padding = 1)
 nn.MaxPool2d(2)
 nn.Conv2d(32, 64, 3, padding = 1)
 nn.MaxPool2d(5)
-nn.Conv2d(64, 64, (3, 4))''')
+nn.Conv2d(64, 64, (3, 4))"""
+    )
     exit(1)
 
+######################################################################
+
+t = 0
+
 while True:
     pt = t
     t = os.stat(sys.argv[1])[stat.ST_MTIME]
     if t > pt:
         pt = t
-        os.system('clear')
+        os.system("clear")
         try:
-            temp = [l.strip('\n\r') for l in open(sys.argv[1], 'r').readlines()]
+            temp = [l.strip("\n\r") for l in open(sys.argv[1], "r").readlines()]
             x = torch.zeros(eval(temp.pop(0)))
-            print('-> ' + str(tuple(x.size())))
+            print("-> " + str(tuple(x.size())))
             for k in temp:
-                print('   ' + k)
-                x = eval(k + '(x)')
-                print('-> ' + str(tuple(x.size())))
+                print("   " + k)
+                x = eval(k + "(x)")
+                print("-> " + str(tuple(x.size())))
         except:
-            print('** Error **')
+            print("** Error **")
     time.sleep(1)