Update.
[beaver.git] / tensorstack.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
import linecache
import sys

from torch import Tensor
11
12
def exception_hook(exc_type, exc_value, tb):
    r"""Hacks the call stack message to show all the local variables in
    case of RuntimeError or ValueError (including their subclasses), and
    prints tensors as shape, dtype and device.

    Meant to be used as ``sys.excepthook`` (it is installed right below):
    ``exc_type`` is the exception class, ``exc_value`` the exception
    instance and ``tb`` the traceback, walked from the outermost frame to
    the innermost one. Everything is written with ``print`` (stdout).

    """

    # Temporarily make tensors print compactly. Restored in ``finally`` so a
    # failure while reporting can never leave Tensor.__repr__ patched.
    repr_orig = Tensor.__repr__
    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"

    try:
        # issubclass (not set membership) so subclasses also get locals.
        show_locals = issubclass(exc_type, (RuntimeError, ValueError))

        while tb:
            print("--------------------------------------------------\n")
            code = tb.tb_frame.f_code
            filename = code.co_filename
            line_no = tb.tb_lineno
            print(f'  File "{filename}", line {line_no}, in {code.co_name}')
            # linecache detects the source encoding, does not leak file
            # handles, and returns "" for sources it cannot read (e.g.
            # "<stdin>" or "<string>") instead of raising.
            source_line = linecache.getline(filename, line_no) if line_no else ""
            print(source_line)

            if show_locals:
                for n, v in tb.tb_frame.f_locals.items():
                    # One badly-behaved object must not abort the report.
                    try:
                        text = f"{v}"
                    except Exception as err:
                        text = f"<unprintable {type(v).__name__}: {err!r}>"
                    print(f"  {n} -> {text}")

            print()
            tb = tb.tb_next
    finally:
        Tensor.__repr__ = repr_orig

    print(f"{exc_type.__name__}: {exc_value}")


# Installed at import time: importing this module activates the hook.
sys.excepthook = exception_hook
44
45 ######################################################################
46
if __name__ == "__main__":
    # Demo of exception_hook: the second call to blah() fails inside
    # dummy(), and the hook prints every frame of the traceback.

    import torch

    def dummy(a, b):
        print(a @ b)

    def blah(a, b):
        c = b + b
        dummy(a, c)

    mmm = torch.randn(2, 3)
    xxx = torch.randn(3)

    # (2, 3) @ (3,) is a valid matrix-vector product and prints a tensor.
    blah(mmm, xxx)
    # (3,) @ (2, 3) has mismatched inner dimensions, so matmul raises,
    # which is what triggers exception_hook.
    blah(xxx, mmm)