parser.add_argument("--log_command", type=str, default=None)
# ----------------------------------
+
parser.add_argument("--nb_epochs", type=int, default=10000)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--learning_rate", type=float, default=5e-4)
+parser.add_argument("--schedule_free", action="store_true", default=False)
+
# ----------------------------------
parser.add_argument("--model", type=str, default=None)
def run_tests(model, quiz_machine, local_device=main_device):
    with torch.autograd.no_grad():
-        model.eval().to(local_device)
+        model.to(local_device).eval()
+        if args.schedule_free:
+            # a schedule-free optimizer must also be put in eval mode so that
+            # its averaged iterate is used for evaluation; plain Adam has no
+            # eval() method, hence the guard
+            model.optimizer.eval()
        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0
def one_epoch(model, quiz_machine, local_device=main_device):
    model.to(local_device).train()
+    if args.schedule_free:
+        # switch the schedule-free optimizer back to train mode before any step()
+        model.optimizer.train()
    nb_train_samples, acc_train_loss = 0, 0.0
return t_q < t_k
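+# importing lazily keeps 'schedulefree' an optional dependency, needed only
+# when --schedule_free is passed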
+if args.schedule_free:
+    import schedulefree
+
for k in range(args.nb_gpts):
log_string(f"creating model {k} and its w_quizzes")
model.id = k
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ if args.schedule_free:
+ model.optimizer = schedulefree.AdamWScheduleFree(
+ model.parameters(), lr=args.learning_rate
+ )
+ else:
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
model.main_test_accuracy = 0.0
model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(