Update: move the optimizer state along with the model when switching devices, and only call optimizer.train()/eval() when a schedule-free optimizer is used.
author     François Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 18:01:16 +0000 (20:01 +0200)
committer  François Fleuret <francois@fleuret.org>
Thu, 8 Aug 2024 18:01:16 +0000 (20:01 +0200)
main.py

diff --git a/main.py b/main.py
index 8bca425..3196fbd 100755
--- a/main.py
+++ b/main.py
@@ -363,10 +363,29 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
-        model.optimizer.eval()
+        if args.schedule_free:
+            model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -398,7 +417,10 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
-    model.optimizer.train()
+    optimizer_to(model.optimizer, local_device)
+
+    if args.schedule_free:
+        model.optimizer.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
@@ -454,6 +476,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     # )
 
     model.to(main_device)
+    optimizer_to(model.optimizer, main_device)
 
 
 ######################################################################
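
A minimal sketch (not part of the commit) of why the new optimizer_to() helper is needed: moving a model with .to() relocates its parameters but not the optimizer's internal state (e.g. AdamW's exp_avg / exp_avg_sq buffers), which stays on the device where optimizer.step() was last run. The helper walks optimizer.state and moves every tensor it finds, including ones nested in per-parameter dicts. The snippet below assumes the optimizer_to() definition from the diff above is in scope.

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # One step on the CPU so that the optimizer state gets populated.
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        model.to(device)                 # moves the parameters ...
        optimizer_to(optimizer, device)  # ... and this moves the optimizer state
        state = next(iter(optimizer.state.values()))
        print(state["exp_avg"].device)   # now reports cuda:0

The second change guards the optimizer.train()/optimizer.eval() calls: plain torch.optim optimizers have no such methods, so they are now made only when args.schedule_free is set, which presumably selects a schedule-free optimizer (e.g. AdamWScheduleFree from the schedulefree package) that does require switching between its train and eval modes.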