X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=79841f3e518b1854bf3dd21034e3a49096fff49c;hb=64dc96ddfa84511ba07d1929481e93e864735409;hp=04e56527a51935496c69dbbb6124bd18bc160f28;hpb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;p=mygptrnn.git diff --git a/main.py b/main.py index 04e5652..79841f3 100755 --- a/main.py +++ b/main.py @@ -133,6 +133,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_nb_colors", type=int, default=6) + +parser.add_argument("--grid_nb_shapes", type=int, default=6) + ############################## # picoclvr options @@ -701,6 +705,8 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + nb_shapes=args.grid_nb_shapes, + nb_colors=args.grid_nb_colors, logger=log_string, device=device_data, ) @@ -835,21 +841,22 @@ if args.max_percents_of_test_in_train >= 0: ############################## -for input in task.batches(split="train", desc="calibrate"): - input = input.to(device) - output = model(mygpt.BracketedSequence(input)).x - -for n, m in model.named_modules(): - for a in dir(m): - x = getattr(m, a) - if isinstance(x, mygpt.Calibrator): - print(f"####### ${n} | ${a} ########################") - mean, std = x.moments() - print("mean\n", mean, "\n") - print("std\n", std, "\n") - print(f"############################################\n\n") - -exit(0) +if "calibrate" in sup_args: + for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + + for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + + exit(0) ##############################