X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=2edfa14de0107376a431632791564451e729c298;hb=b1d28a1ed672be21947509dac2f90666b65b5034;hp=958dfc70a791d2b47708b10e7f44e53e7456fc2e;hpb=8ea809c43242d3a2e063692105919a86c3f6fe6b;p=picoclvr.git diff --git a/main.py b/main.py index 958dfc7..2edfa14 100755 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ parser.add_argument( "--task", type=str, default="twotargets", - help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp", + help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, escape", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -175,6 +175,15 @@ parser.add_argument("--mixing_hard", action="store_true", default=False) parser.add_argument("--mixing_deterministic_start", action="store_true", default=False) +############################## +# escape options + +parser.add_argument("--escape_height", type=int, default=4) + +parser.add_argument("--escape_width", type=int, default=6) + +parser.add_argument("--escape_T", type=int, default=25) + ###################################################################### args = parser.parse_args() @@ -289,6 +298,12 @@ default_task_args = { "nb_train_samples": 60000, "nb_test_samples": 10000, }, + "escape": { + "model": "37M", + "batch_size": 25, + "nb_train_samples": 25000, + "nb_test_samples": 10000, + }, } if args.task in default_task_args: @@ -599,6 +614,18 @@ elif args.task == "qmlp": device=device, ) +elif args.task == "escape": + task = tasks.Escape( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=args.escape_height, + width=args.escape_width, + T=args.escape_T, + logger=log_string, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}")