X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=55379652a8766a72fa8b20035d4176702f791251;hb=525bd24014786b53638dea78cfb88035a2b99d97;hp=d7fb3d1119be0f94ea6ed18c7db39af6b03c6ac9;hpb=bfcef9a8c82ed45528601e85725166241bbee916;p=culture.git

diff --git a/main.py b/main.py
index d7fb3d1..5537965 100755
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, quizz_machine
+import sky, wireworld, quizz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -37,7 +37,7 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--result_dir", type=str, default=None)
 
@@ -79,6 +79,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--problem", type=str, default="sky")
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
 parser.add_argument("--nb_models_for_generation", type=int, default=1)
@@ -219,8 +221,15 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+if args.problem == "sky":
+    problem = (sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),)
+elif args.problem == "wireworld":
+    problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+else:
+    raise ValueError
+
 quizz_machine = quizz_machine.QuizzMachine(
-    problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
+    problem=problem,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
@@ -343,7 +352,6 @@ def run_tests(model, quizz_machine, deterministic_synthesis):
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
-            logger=log_string,
             deterministic_synthesis=deterministic_synthesis,
         )
 
@@ -397,7 +405,6 @@ def create_c_quizzes(
            min_ave_seq_logproba=min_ave_seq_logproba,
             n_epoch=n_epoch,
             result_dir=args.result_dir,
-            logger=log_string,
         )
 
         sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
@@ -487,7 +494,8 @@ for n_epoch in range(args.nb_epochs):
     a = [(model.id, float(model.main_test_accuracy)) for model in models]
     a.sort(key=lambda p: p[0])
 
-    log_string(f"current accuracies {a}")
+    s = " ".join([f"{p[1]*100:.02f}%" for p in a])
+    log_string(f"current accuracies {s}")
 
     # select the model with lowest accuracy
     models.sort(key=lambda model: model.main_test_accuracy)
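
For context, this change routes problem construction through the new --problem flag and reports per-model accuracies as percentages instead of raw (id, accuracy) tuples. The standalone sketch below mirrors that logic so it can be run outside the repository; the Sky and Wireworld classes are hypothetical stand-ins (only their constructor arguments come from the diff), and the accuracy pairs are made-up example values.

import argparse

# Hypothetical stand-ins for sky.Sky and wireworld.Wireworld; only the
# constructor arguments are taken from the diff above.
class Sky:
    def __init__(self, height, width, nb_birds, nb_iterations, speed):
        self.height, self.width = height, width

class Wireworld:
    def __init__(self, height, width, nb_iterations):
        self.height, self.width = height, width

parser = argparse.ArgumentParser()
parser.add_argument("--problem", type=str, default="sky")
args = parser.parse_args([])  # empty list: use the defaults for this illustration

# Problem selection, as introduced before the QuizzMachine construction.
if args.problem == "sky":
    problem = Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2)
elif args.problem == "wireworld":
    problem = Wireworld(height=10, width=15, nb_iterations=4)
else:
    raise ValueError(f"unknown problem {args.problem}")

# Accuracy reporting as in the last hunk: percentages rather than raw tuples.
a = [(0, 0.8731), (1, 0.9012)]  # made-up (model.id, main_test_accuracy) pairs
s = " ".join([f"{p[1]*100:.02f}%" for p in a])
print(f"current accuracies {s}")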