From 0f580d4facb4b4b485d0a38d62d06c0639715b77 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 22:51:28 +0200 Subject: [PATCH] Update. --- ffutils.py | 108 +++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 1 + tasks.py | 9 +++-- world.py | 8 +++- 4 files changed, 120 insertions(+), 6 deletions(-) create mode 100755 ffutils.py diff --git a/ffutils.py b/ffutils.py new file mode 100755 index 0000000..45f44d8 --- /dev/null +++ b/ffutils.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import torch +import sys, contextlib + +import torch +from torch import Tensor + +###################################################################### + + +@contextlib.contextmanager +def evaluation(*models): + with torch.inference_mode(): + t = [(m, m.training) for m in models] + for m in models: + m.train(False) + yield + for m, u in t: + m.train(u) + + +###################################################################### + +from torch.utils._python_dispatch import TorchDispatchMode + + +def hasNaN(x): + if torch.is_tensor(x): + return x.isnan().max() + else: + try: + return any([hasNaN(y) for y in x]) + except TypeError: + return False + + +class NaNDetect(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args, kwargs=None): + kwargs = kwargs or {} + res = func(*args, **kwargs) + + if hasNaN(res): + raise RuntimeError( + f"Function {func}(*{args}, **{kwargs}) " "returned a NaN" + ) + return res + + +###################################################################### + + +def exception_hook(exc_type, exc_value, tb): + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. + + """ + + repr_orig = Tensor.__repr__ + Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}" + + while tb: + print("--------------------------------------------------\n") + filename = tb.tb_frame.f_code.co_filename + name = tb.tb_frame.f_code.co_name + line_no = tb.tb_lineno + print(f' File "{filename}", line {line_no}, in {name}') + print(open(filename, "r").readlines()[line_no - 1]) + + if exc_type in {RuntimeError, ValueError, IndexError, TypeError}: + for n, v in tb.tb_frame.f_locals.items(): + print(f" {n} -> {v}") + + print() + tb = tb.tb_next + + Tensor.__repr__ = repr_orig + + print(f"{exc_type.__name__}: {exc_value}") + + +def activate_tensorstack(): + sys.excepthook = exception_hook + + +###################################################################### + +if __name__ == "__main__": + import torch + + def dummy(a, b): + print(a @ b) + + def blah(a, b): + c = b + b + dummy(a, c) + + mmm = torch.randn(2, 3) + xxx = torch.randn(3) + # print(xxx@mmm) + blah(mmm, xxx) + blah(xxx, mmm) diff --git a/main.py b/main.py index 69ee58f..e18887b 100755 --- a/main.py +++ b/main.py @@ -14,6 +14,7 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +import ffutils import mygpt, tasks ###################################################################### diff --git a/tasks.py b/tasks.py index 5583fc8..9cd06ae 100755 --- a/tasks.py +++ b/tasks.py @@ -75,11 +75,12 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:,len_prompt]=-1 + self.seq[:, len_prompt] = -1 def generate_sequences(self, nb): return self.seq[torch.randint(self.seq.size(0), (nb,))] + class SandBox(Task): def __init__( self, @@ -93,7 +94,7 @@ class SandBox(Task): self.batch_size = batch_size - problems = [ ProblemByheart() ] + problems = [ProblemByheart()] nb_common_codes = 100 def generate_sequences(nb_samples): @@ -101,7 +102,7 @@ class SandBox(Task): nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") all_seq = [] - for nb, p in zip(nb_samples_per_problem,problems): + for nb, p in zip(nb_samples_per_problem, problems): all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) return all_seq @@ -109,7 +110,7 @@ class SandBox(Task): test_seq = generate_sequences(nb_test_samples) for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain,stest),0) + s = torch.cat((strain, stest), 0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 diff --git a/world.py b/world.py index 3d6abbe..b35a08e 100755 --- a/world.py +++ b/world.py @@ -85,9 +85,9 @@ def loss_H(binary_logits, h_threshold=1): def train_encoder( train_input, test_input, - depth=2, + depth, + nb_bits_per_token, dim_hidden=48, - nb_bits_per_token=8, lambda_entropy=0.0, lr_start=1e-3, lr_end=1e-4, @@ -366,6 +366,8 @@ def create_data_and_processors( nb_test_samples, mode, nb_steps, + depth=3, + nb_bits_per_token=8, nb_epochs=10, device=torch.device("cpu"), device_storage=torch.device("cpu"), @@ -388,6 +390,8 @@ def create_data_and_processors( encoder, quantizer, decoder = train_encoder( train_input, test_input, + depth=depth, + nb_bits_per_token=nb_bits_per_token, lambda_entropy=1.0, nb_epochs=nb_epochs, logger=logger, -- 2.20.1