3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
12 from torch import Tensor
14 ######################################################################
17 @contextlib.contextmanager
18 def evaluation(*models):
19 with torch.inference_mode():
20 t = [(m, m.training) for m in models]
28 ######################################################################
30 from torch.utils._python_dispatch import TorchDispatchMode
34 if torch.is_tensor(x):
35 return x.numel() > 0 and x.isnan().max()
38 return any([hasNaN(y) for y in x])
43 class NaNDetect(TorchDispatchMode):
44 def __torch_dispatch__(self, func, types, args, kwargs=None):
46 res = func(*args, **kwargs)
50 f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
55 ######################################################################
58 def exception_hook(exc_type, exc_value, tb):
59 r"""Hacks the call stack message to show all the local variables
60 in case of relevant error, and prints tensors as shape, dtype and
65 repr_orig = Tensor.__repr__
66 Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
69 print("--------------------------------------------------\n")
70 filename = tb.tb_frame.f_code.co_filename
71 name = tb.tb_frame.f_code.co_name
72 line_no = tb.tb_lineno
73 print(f' File "{filename}", line {line_no}, in {name}')
74 print(open(filename, "r").readlines()[line_no - 1])
76 if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
77 for n, v in tb.tb_frame.f_locals.items():
83 Tensor.__repr__ = repr_orig
85 print(f"{exc_type.__name__}: {exc_value}")
def activate_tensorstack():
    """Install ``exception_hook`` as the interpreter-wide uncaught-exception handler.

    After this call, any exception that propagates to the top level is
    reported through ``exception_hook`` instead of the default handler.

    Returns:
        The hook that was installed in ``sys.excepthook`` before this call,
        so a caller can undo the activation with ``sys.excepthook = previous``.
        Callers that ignore the return value behave exactly as before.
    """
    # Keep a handle on the previous hook; otherwise activation is irreversible.
    previous = sys.excepthook
    sys.excepthook = exception_hook
    return previous
92 ######################################################################
94 if __name__ == "__main__":
104 mmm = torch.randn(2, 3)