######################################################################
+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
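+
+# Minimal usage sketch (illustrative only, not part of the training loop):
+# torch.nn.Module.to() moves the model's parameters but leaves the
+# optimizer's per-parameter state (e.g. Adam moment buffers) behind, so
+# after restoring a checkpoint on the CPU both must be moved explicitly.
+# `checkpoint.pth` and the state-dict keys are hypothetical names.
+#
+#     state = torch.load("checkpoint.pth", map_location="cpu")
+#     model.load_state_dict(state["model"])
+#     model.optimizer.load_state_dict(state["optimizer"])
+#     model.to(main_device)
+#     optimizer_to(model.optimizer, main_device)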
+
+
+######################################################################
+
+
def run_tests(model, quiz_machine, local_device=main_device):
    with torch.autograd.no_grad():
        model.to(local_device).eval()
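+        # Schedule-free optimizers keep separate train/eval parameter
+        # settings, so the optimizer itself must be switched to eval mode
+        # along with the model (assuming args.schedule_free selects an
+        # optimizer such as schedulefree.AdamWScheduleFree, which exposes
+        # .eval() and .train()).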
-        model.optimizer.eval()
+        if args.schedule_free:
+            model.optimizer.eval()
        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0
def one_epoch(model, quiz_machine, local_device=main_device):
    model.to(local_device).train()
-    model.optimizer.train()
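+    # Move the optimizer's per-parameter state (e.g. Adam moment buffers)
+    # to the same device as the model; Module.to() does not do this.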
+    optimizer_to(model.optimizer, local_device)
+
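+    # A schedule-free optimizer must also be switched back to train mode.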
+    if args.schedule_free:
+        model.optimizer.train()
    nb_train_samples, acc_train_loss = 0, 0.0
# )
    model.to(main_device)
+    optimizer_to(model.optimizer, main_device)
######################################################################