From 70d2428cd6e4caaf5c81c6cb77961866405a4cd5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 08:54:11 +0200 Subject: [PATCH] Update. --- main.py | 44 +++----------------------------------------- tasks.py | 5 ++++- 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/main.py b/main.py index 11d712a..6c27599 100755 --- a/main.py +++ b/main.py @@ -46,7 +46,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) ######################################## -parser.add_argument("--nb_epochs", type=int, default=100) +parser.add_argument("--nb_epochs", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=None) @@ -56,12 +56,8 @@ parser.add_argument("--nb_train_samples", type=int, default=None) parser.add_argument("--nb_test_samples", type=int, default=None) -parser.add_argument("--optim", type=str, default="adam") - parser.add_argument("--learning_rate", type=float, default=1e-4) -parser.add_argument("--learning_rate_schedule", type=str, default=None) - ######################################## parser.add_argument("--model", type=str, default=None) @@ -716,43 +712,9 @@ if args.max_percents_of_test_in_train >= 0: ############################## -if args.learning_rate_schedule == "cos": - learning_rate_schedule = {} - for n_epoch in range(args.nb_epochs): - u = n_epoch / args.nb_epochs * math.pi - learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u)) -else: - if args.learning_rate_schedule is not None: - u = { - int(k): float(v) - for k, v in [ - tuple(x.split(":")) for x in args.learning_rate_schedule.split(",") - ] - } - else: - u = {} - - learning_rate_schedule = {} - learning_rate = args.learning_rate - for n_epoch in range(args.nb_epochs): - if n_epoch in u: - learning_rate = u[n_epoch] - learning_rate_schedule[n_epoch] = learning_rate - -log_string(f"learning_rate_schedule {learning_rate_schedule}") - -###################################################################### - def one_epoch(model, task): - if args.optim == "sgd": - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) - elif args.optim == "adam": - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - elif args.optim == "adamw": - optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) - else: - raise ValueError(f"Unknown optimizer {args.optim}.") + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) model.train() @@ -851,7 +813,7 @@ def create_quizzes( task.save_image( new_quizzes[:96], args.result_dir, - f"world_new_{n_epoch:04d}.png", + f"world_new_{n_epoch:04d}_{model.id:02d}.png", log_string, ) diff --git a/tasks.py b/tasks.py index b4829d9..5d9a018 100755 --- a/tasks.py +++ b/tasks.py @@ -2236,7 +2236,10 @@ class World(Task): ) self.save_image( - result[:96], result_dir, f"world_result_{n_epoch:04d}.png", logger + result[:96], + result_dir, + f"world_result_{n_epoch:04d}_{model.id:02d}.png", + logger, ) return main_test_accuracy -- 2.39.5