Update.
[pytorch.git] / tensorstack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 from torch import Tensor
9
10 import sys
11
12
def exception_hook(exc_type, exc_value, tb):
    r"""Print an uncaught exception's traceback along with local variables.

    Meant to be installed as ``sys.excepthook``. For each frame of the
    traceback, the file name, line number, function name and source
    line are printed to standard output. When the exception type is
    exactly RuntimeError, ValueError, or TypeError, the local variables
    of every frame are printed as well. While the report is produced,
    tensors are rendered as ``size:dtype:device`` instead of their
    content.

    Args:
        exc_type: class of the uncaught exception.
        exc_value: the exception instance.
        tb: traceback object attached to the exception (may be None).

    """

    # Local import: keeps this function self-contained without touching
    # the module-level import block.
    import linecache

    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"

    # The monkey-patch above is global, so it must be undone even if
    # printing a frame fails (e.g. a local variable whose repr raises).
    try:
        while tb:
            print("--------------------------------------------------\n")
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            name = frame.f_code.co_name
            line_no = tb.tb_lineno
            print(f'  File "{filename}", line {line_no}, in {name}')
            # linecache closes the file it reads and yields "" when the
            # source is unavailable ("<string>", "<stdin>", ...) instead
            # of raising from inside the hook.
            print(linecache.getline(filename, line_no, frame.f_globals))

            if exc_type in {RuntimeError, ValueError, TypeError}:
                for n, v in frame.f_locals.items():
                    print(f"  {n} -> {v}")

            print()
            tb = tb.tb_next
    finally:
        Tensor.__repr__ = repr_orig

    print(f"{exc_type.__name__}: {exc_value}")


# Route every uncaught exception of the process through exception_hook.
sys.excepthook = exception_hook
44
45 ######################################################################
46
if __name__ == "__main__":
    # Demonstration: run this file directly to see the report produced
    # for an uncaught exception.
    import torch

    def dummy(a, b):
        """Print the matrix product of a and b."""
        print(a @ b)

    def blah(a, b):
        """Double b, then hand a and the result over to dummy()."""
        c = b + b
        dummy(a, c)

    mmm = torch.randn(2, 3)
    xxx = torch.randn(3)

    # (2, 3) @ (3,) is a valid product: this call just prints the result.
    blah(mmm, xxx)

    # (3,) @ (2, 3) has mismatched sizes: the resulting uncaught
    # RuntimeError is what the demo is meant to trigger.
    blah(xxx, mmm)