Update.
[pytorch.git] / sizer.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import os, stat, sys
9 import time
10 import torch
11 from torch import nn
12
13 ######################################################################
14
15 if len(sys.argv) < 2:
16     print(
17         sys.argv[0]
18         + """ <file to monitor>
19
20 For example:
21
22 (17, 3, 60, 80)
23 nn.Conv2d(3, 32, 3, padding = 1)
24 nn.MaxPool2d(2)
25 nn.Conv2d(32, 32, 3, padding = 1)
26 nn.MaxPool2d(2)
27 nn.Conv2d(32, 64, 3, padding = 1)
28 nn.MaxPool2d(5)
29 nn.Conv2d(64, 64, (3, 4))"""
30     )
31     exit(1)
32
33 ######################################################################
34
35 t = 0
36
37 while True:
38     pt = t
39     t = os.stat(sys.argv[1])[stat.ST_MTIME]
40     if t > pt:
41         pt = t
42         os.system("clear")
43         try:
44             temp = [l.strip("\n\r") for l in open(sys.argv[1], "r").readlines()]
45             x = torch.zeros(eval(temp.pop(0)))
46             print("-> " + str(tuple(x.size())))
47             for k in temp:
48                 print("   " + k)
49                 x = eval(k + "(x)")
50                 print("-> " + str(tuple(x.size())))
51         except:
52             print("** Error **")
53     time.sleep(1)