Update.
[pytorch.git] / tensorstack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import Tensor
9
10 import sys
11
def exception_hook(exc_type, exc_value, tb):
    r'''Hacks the call stack message in case of RuntimeError to show all
    the local variables, and indicate for every tensor involved its
    shape, dtype and device.

    Meant to be installed as ``sys.excepthook``.

    Output, for every frame of the traceback:
      - the file, line number, function name and source line;
      - when the exception is a RuntimeError or a subclass of it, every
        local variable of the frame.

    Tensor formatting: while printing, ``Tensor.__repr__`` is temporarily
    replaced so that tensors, including tensors nested inside containers,
    print as ``size:dtype:device`` instead of their full content. The
    original ``__repr__`` is always restored, even if printing fails.

    Finally, it prints ``ExceptionName: message``.

    Args:
        exc_type: the exception class.
        exc_value: the exception instance.
        tb: the traceback object (may be None).
    '''
    # linecache does not raise for frames whose "file" does not exist
    # (e.g. "<string>", "<stdin>") and does not leak file handles.
    import linecache

    def describe(v):
        # Tensors go through repr() so that 0-dim tensors also get the
        # size:dtype:device form; Tensor.__format__ (what an f-string
        # uses) returns the plain Python scalar for 0-dim tensors.
        # Any other value keeps the original f-string formatting.
        try:
            return repr(v) if isinstance(v, Tensor) else f'{v}'
        except Exception as e:
            # One unprintable local must not abort the whole report.
            return f'<unprintable {type(v).__name__} ({type(e).__name__})>'

    show_locals = issubclass(exc_type, RuntimeError)

    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f'{x.size()}:{x.dtype}:{x.device}'

    try:
        while tb:
            print('--------------------------------------------------')
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            name = frame.f_code.co_name
            line_no = tb.tb_lineno
            print(f'  File "{filename}", line {line_no}, in {name}')
            # Re-validate the cache so an edited file is re-read, like a
            # fresh open() would.
            linecache.checkcache(filename)
            # getline returns '' for unknown files, so nothing is printed.
            print(linecache.getline(filename, line_no), end='')

            if show_locals:
                for n, v in frame.f_locals.items():
                    print(f'  {n} -> {describe(v)}')

            tb = tb.tb_next
    finally:
        # Never leave the global Tensor.__repr__ monkey-patched.
        Tensor.__repr__ = repr_orig

    print(f'{exc_type.__name__}: {exc_value}')


# Install the hook as soon as this module is imported.
sys.excepthook = exception_hook
41
42 ######################################################################
43
if __name__ == '__main__':
    # Demo: blah() doubles its second argument and dummy() prints the
    # matrix product. The first call is (2, 3) @ (3,), which is valid.
    # The second call is (3,) @ (2, 3), which torch rejects with a
    # shape-mismatch error; that error is reported through the hook.
    import torch

    def dummy(a, b):
        """Print the matrix product a @ b."""
        print(a @ b)

    def blah(a, b):
        """Double b and pass it, together with a, to dummy()."""
        c = b + b
        dummy(a, c)

    mmm = torch.randn(2, 3)
    xxx = torch.randn(3)

    blah(mmm, xxx)  # (2, 3) @ (3,): prints a 2-element tensor
    blah(xxx, mmm)  # (3,) @ (2, 3): shape mismatch, triggers the hook