Now also catch ValueError.
[pytorch.git] / tensorstack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import Tensor
9
10 import sys
11
def exception_hook(exc_type, exc_value, tb):
    r'''Replacement for sys.excepthook that prints one section per
    traceback frame and, for RuntimeError / ValueError (including their
    subclasses), dumps the local variables of every frame. While the
    report is being printed, tensors are shown as 'size:dtype:device'
    instead of their values.

    Args:
        exc_type: class of the uncaught exception.
        exc_value: the exception instance.
        tb: traceback object of the exception (may be None).

    The report is written to the standard output.

    '''

    # Local import so the function is self-contained.
    import linecache

    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f'{x.size()}:{x.dtype}:{x.device}'

    # issubclass rather than exact membership, so that subclasses of
    # RuntimeError / ValueError get the detailed report as well.
    show_locals = issubclass(exc_type, (RuntimeError, ValueError))

    try:
        while tb:
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            name = frame.f_code.co_name
            line_no = tb.tb_lineno

            print('--------------------------------------------------\n')
            print(f'  File "{filename}", line {line_no}, in {name}')
            # linecache closes the file it reads and returns '' when
            # the source is unavailable (e.g. '<string>', '<stdin>'),
            # so the hook itself never fails on such frames.
            print(linecache.getline(filename, line_no, frame.f_globals))

            if show_locals:
                for n, v in frame.f_locals.items():
                    print(f'  {n} -> {v}')

            print()
            tb = tb.tb_next
    finally:
        # Always undo the monkey-patch, even if printing a frame failed.
        Tensor.__repr__ = repr_orig

    print(f'{exc_type.__name__}: {exc_value}')
40
# Install the hook at import time: from here on, any uncaught exception
# in the process is reported through exception_hook.
sys.excepthook = exception_hook
42
43 ######################################################################
44
# Small demo, run only when the file is executed as a script: it triggers
# an uncaught error two calls deep so the hook's per-frame report can be
# seen.
if __name__ == '__main__':

    import torch

    # Innermost frame: the matrix product is where the error is raised.
    def dummy(a,b):
        print(a@b)

    # Intermediate frame, adds a local `c` to the report.
    def blah(a,b):
        c=b+b
        dummy(a,c)

    mmm=torch.randn(2,3)
    xxx=torch.randn(3)
    # Alternative one-line trigger, kept disabled:
    #print(xxx@mmm)
    # (2,3) @ (3,) is a valid product, so this call just prints a result.
    blah(mmm,xxx)
    # (3,) @ (2,3) has mismatched sizes; this call is expected to raise
    # and exercise the hook.
    blah(xxx,mmm)