From 7f4fa36941cb368915c6a81549077309ee25ab44 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 23 Sep 2024 17:42:10 +0200
Subject: [PATCH] Update.

---
 ffutils.py | 108 -----------------------------------------------------
 main.py    |   7 +++-
 2 files changed, 6 insertions(+), 109 deletions(-)
 delete mode 100755 ffutils.py

diff --git a/ffutils.py b/ffutils.py
deleted file mode 100755
index 45f44d8..0000000
--- a/ffutils.py
+++ /dev/null
@@ -1,108 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret
-
-import torch
-import sys, contextlib
-
-import torch
-from torch import Tensor
-
-######################################################################
-
-
-@contextlib.contextmanager
-def evaluation(*models):
-    with torch.inference_mode():
-        t = [(m, m.training) for m in models]
-        for m in models:
-            m.train(False)
-        yield
-        for m, u in t:
-            m.train(u)
-
-
-######################################################################
-
-from torch.utils._python_dispatch import TorchDispatchMode
-
-
-def hasNaN(x):
-    if torch.is_tensor(x):
-        return x.isnan().max()
-    else:
-        try:
-            return any([hasNaN(y) for y in x])
-        except TypeError:
-            return False
-
-
-class NaNDetect(TorchDispatchMode):
-    def __torch_dispatch__(self, func, types, args, kwargs=None):
-        kwargs = kwargs or {}
-        res = func(*args, **kwargs)
-
-        if hasNaN(res):
-            raise RuntimeError(
-                f"Function {func}(*{args}, **{kwargs}) " "returned a NaN"
-            )
-        return res
-
-
-######################################################################
-
-
-def exception_hook(exc_type, exc_value, tb):
-    r"""Hacks the call stack message to show all the local variables
-    in case of relevant error, and prints tensors as shape, dtype and
-    device.
-
-    """
-
-    repr_orig = Tensor.__repr__
-    Tensor.__repr__ = lambda x: f"{x.size()}:{x.dtype}:{x.device}"
-
-    while tb:
-        print("--------------------------------------------------\n")
-        filename = tb.tb_frame.f_code.co_filename
-        name = tb.tb_frame.f_code.co_name
-        line_no = tb.tb_lineno
-        print(f'  File "{filename}", line {line_no}, in {name}')
-        print(open(filename, "r").readlines()[line_no - 1])
-
-        if exc_type in {RuntimeError, ValueError, IndexError, TypeError}:
-            for n, v in tb.tb_frame.f_locals.items():
-                print(f"  {n} -> {v}")
-
-        print()
-        tb = tb.tb_next
-
-    Tensor.__repr__ = repr_orig
-
-    print(f"{exc_type.__name__}: {exc_value}")
-
-
-def activate_tensorstack():
-    sys.excepthook = exception_hook
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import torch
-
-    def dummy(a, b):
-        print(a @ b)
-
-    def blah(a, b):
-        c = b + b
-        dummy(a, c)
-
-    mmm = torch.randn(2, 3)
-    xxx = torch.randn(3)
-    # print(xxx@mmm)
-    blah(mmm, xxx)
-    blah(xxx, mmm)
diff --git a/main.py b/main.py
index 9e8589a..ed43b33 100755
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
-import ffutils, grids, attae
+import grids, attae
 
 import threading, subprocess
 
@@ -1009,6 +1009,11 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
     weakest_models = ranked_models[: len(gpus)]
+    weakest_models = list(
+        filter(
+            lambda m: m.test_accuracy < args.accuracy_to_make_c_quizzes, weakest_models
+        )
+    )
 
     log_string(
         f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
-- 
2.39.5