From e8e9b3941f150b20aa9585f7fa0a1f5e2fe6f547 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 21 Jun 2024 21:58:14 +0200
Subject: [PATCH] Update.

---
 main.py  | 38 ++++++++++++++++++++------------------
 world.py |  2 +-
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/main.py b/main.py
index 4a1207d..61d77ed 100755
--- a/main.py
+++ b/main.py
@@ -32,7 +32,7 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "--task",
     type=str,
-    default="twotargets",
+    default="world",
     help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, greed",
 )
 
@@ -46,7 +46,7 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=100)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
@@ -60,7 +60,7 @@ parser.add_argument("--optim", type=str, default="adam")
 
 parser.add_argument("--learning_rate", type=float, default=1e-4)
 
-parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
+parser.add_argument("--learning_rate_schedule", type=str, default=None)
 
 ########################################
 
@@ -374,9 +374,8 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.resume:
-        print(f"result directory {args.result_dir} already exists")
-        exit(1)
+    print(f"result directory {args.result_dir} already exists")
+    exit(1)
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -745,12 +744,15 @@ if args.learning_rate_schedule == "cos":
         u = n_epoch / args.nb_epochs * math.pi
         learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
 else:
-    u = {
-        int(k): float(v)
-        for k, v in [
-            tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
-        ]
-    }
+    if args.learning_rate_schedule is not None:
+        u = {
+            int(k): float(v)
+            for k, v in [
+                tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+            ]
+        }
+    else:
+        u = {}
 
     learning_rate_schedule = {}
     learning_rate = args.learning_rate
@@ -890,19 +892,19 @@ def create_quizzes(
 ######################################################################
 
-accuracy_to_make_quizzes = 0.95
+accuracy_to_make_quizzes = 0.975
 
-for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
     for m in models:
         one_epoch(m, task, learning_rate)
 
         test_accuracy = run_tests(m, task, deterministic_synthesis=False)
 
-        if test_accuracy >= accuracy_to_make_quizzes:
-            other_models = models.copy()
-            other_models.remove(model)
-            create_quizzes(other_models, task)
+        if test_accuracy >= accuracy_to_make_quizzes:
+            other_models = models.copy()
+            other_models.remove(m)
+            create_quizzes(m, other_models, task)
 
     # --------------------------------------------
 
diff --git a/world.py b/world.py
index 43126d5..89833e6 100755
--- a/world.py
+++ b/world.py
@@ -22,7 +22,7 @@ colors = torch.tensor(
         [255, 0, 0],
         [0, 128, 0],
         [0, 0, 255],
-        [255, 255, 0],
+        [255, 200, 0],
         [192, 192, 192],
     ]
 )
-- 
2.20.1
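
For reference, a minimal standalone sketch of the schedule parsing that the main.py hunk around line 744 now guards against the new None default. The parse_lr_schedule name is a hypothetical helper used only for illustration; it is not part of the patch, but mirrors the dict comprehension in main.py.

    # Hypothetical helper, for illustration only: turns a string like
    # "10: 2e-5,30: 4e-6" into {epoch: learning_rate}, and returns {}
    # when the schedule string is None (the new default in this patch).
    def parse_lr_schedule(spec):
        if spec is None:
            return {}
        return {int(k): float(v) for k, v in (x.split(":") for x in spec.split(","))}

    assert parse_lr_schedule(None) == {}
    assert parse_lr_schedule("10: 2e-5,30: 4e-6") == {10: 2e-5, 30: 4e-6}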