From a2b35c224e66f7e17612c0e8de2462c9e998e051 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 8 Aug 2024 20:01:16 +0200
Subject: [PATCH] Update.

---
 main.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index 8bca425..3196fbd 100755
--- a/main.py
+++ b/main.py
@@ -363,10 +363,29 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
-        model.optimizer.eval()
+        if args.schedule_free:
+            model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -398,7 +417,10 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
-    model.optimizer.train()
+    optimizer_to(model.optimizer, local_device)
+
+    if args.schedule_free:
+        model.optimizer.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
@@ -454,6 +476,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     #     )
 
     model.to(main_device)
+    optimizer_to(model.optimizer, main_device)
 
 
 ######################################################################
-- 
2.39.5
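
Note (not part of the patch): a minimal sketch of how the new helper is meant to pair with model.to() in one_epoch, i.e. move the optimizer state to the training device before stepping and back to the main device afterwards. The toy nn.Linear model, the Adam optimizer and the device names below are illustrative assumptions, not taken from main.py, and the helper here is a condensed variant of the patch's optimizer_to.

import torch
import torch.nn as nn


def optimizer_to(optim, device):
    # Condensed variant of the patch's helper: move every tensor held in the
    # optimizer state (e.g. Adam's exp_avg / exp_avg_sq) to the target device.
    for state in optim.state.values():
        if isinstance(state, torch.Tensor):
            state.data = state.data.to(device)
        elif isinstance(state, dict):
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)


local_device = "cuda" if torch.cuda.is_available() else "cpu"
main_device = "cpu"

model = nn.Linear(8, 2)  # illustrative stand-in for the real model
optimizer = torch.optim.Adam(model.parameters())

# Train on the local device: model.to() moves the parameters in place, so the
# optimizer keeps valid references, but its state tensors must follow too.
model.to(local_device).train()
optimizer_to(optimizer, local_device)

x = torch.randn(4, 8, device=local_device)
model(x).sum().backward()
optimizer.step()
optimizer.zero_grad()

# Hand everything back to the main device, keeping the model and the optimizer
# state on the same device, as the end of one_epoch does.
model.to(main_device)
optimizer_to(optimizer, main_device)

The args.schedule_free guard in the patch presumably reflects that only schedule-free optimizers expose .train()/.eval(); the sketch uses plain Adam and therefore omits those calls.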