X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=sizer.py;h=dff36ebc6ba9f279e270339f8a6c548bb8140f86;hp=cc0a19e26d19f5157de277cda7e8b0cea80f1688;hb=HEAD;hpb=f99e2c83638c960d158c17270c072876834df9a9 diff --git a/sizer.py b/sizer.py index cc0a19e..5887e4a 100755 --- a/sizer.py +++ b/sizer.py @@ -10,10 +10,12 @@ import time import torch from torch import nn -t = 0 +###################################################################### if len(sys.argv) < 2: - print(sys.argv[0] + ''' + print( + sys.argv[0] + + """ For example: @@ -24,23 +26,28 @@ nn.Conv2d(32, 32, 3, padding = 1) nn.MaxPool2d(2) nn.Conv2d(32, 64, 3, padding = 1) nn.MaxPool2d(5) -nn.Conv2d(64, 64, (3, 4))''') +nn.Conv2d(64, 64, (3, 4))""" + ) exit(1) +###################################################################### + +t = 0 + while True: pt = t t = os.stat(sys.argv[1])[stat.ST_MTIME] if t > pt: pt = t - os.system('clear') + os.system("clear") try: - temp = [l.strip('\n\r') for l in open(sys.argv[1], 'r').readlines()] + temp = [l.strip("\n\r") for l in open(sys.argv[1], "r").readlines()] x = torch.zeros(eval(temp.pop(0))) - print('-> ' + str(tuple(x.size()))) + print("-> " + str(tuple(x.size()))) for k in temp: - print(' ' + k) - x = eval(k + '(x)') - print('-> ' + str(tuple(x.size()))) + print(" " + k) + x = eval(k + "(x)") + print("-> " + str(tuple(x.size()))) except: - print('** Error **') + print("** Error **") time.sleep(1)