X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=74e1d6c9ea7882b91887adc91122f1c2a2d00464;hb=75e1ddcb8de30a4a7be16c80c4f258da662837a6;hp=fabebddcffc31b5c383f2cacde6b8016e03ea1c5;hpb=664758db86b059b68cd11e889a20cc9681e4324a;p=mygptrnn.git diff --git a/main.py b/main.py index fabebdd..74e1d6c 100755 --- a/main.py +++ b/main.py @@ -24,6 +24,17 @@ else: ###################################################################### + +def str2bool(x): + x = x.lower() + if x in {"1", "true", "yes"}: + return True + elif x in {"0", "false", "no"}: + return False + else: + raise ValueError + + parser = argparse.ArgumentParser( description="An implementation of GPT with cache.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -68,7 +79,7 @@ parser.add_argument("--min_learning_rate", type=float, default=6e-5) # legacy -parser.add_argument("--legacy_lr_schedule", action="store_true", default=False) +parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True) parser.add_argument("--legacy_large_lr", type=float, default=1e-4) @@ -96,8 +107,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None) parser.add_argument("--rho", type=float, default=0.0) -parser.add_argument("--dim_rec_v", type=int, default=None) - parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) @@ -321,7 +330,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 32, "nb_heads": 2, - "dim_rec_v": 16, "nb_blocks": 2, }, "17K-C": { @@ -332,7 +340,6 @@ default_model_args = { "nb_heads": 2, "nb_lines": 16, "caterpillar_height": 4, - "dim_rec_v": 16, "nb_blocks": 2, }, "4M": { @@ -341,7 +348,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 1024, "nb_heads": 4, - "dim_rec_v": 64, "nb_blocks": 6, }, "4M-C": { @@ -352,7 +358,6 @@ default_model_args = { "nb_heads": 4, "nb_lines": 32, "caterpillar_height": 4, - "dim_rec_v": 64, # dim_model / nb_heads "nb_blocks": 6, }, "37M": { @@ -361,7 +366,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 64, "nb_blocks": 12, }, "37M-C": { @@ -372,7 +376,6 @@ default_model_args = { "nb_heads": 8, "nb_lines": 256, "caterpillar_height": 32, - "dim_rec_v": 64, "nb_blocks": 12, }, "122M": { @@ -381,7 +384,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 96, "nb_blocks": 24, }, "122M-C": { @@ -391,7 +393,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 96, "nb_blocks": 24, }, "352M": { @@ -400,7 +401,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 128, "nb_blocks": 48, }, "352M-C": { @@ -410,7 +410,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 128, "nb_blocks": 48, }, } @@ -478,7 +477,7 @@ def get_lr(n_epoch, it): if it < args.nb_warmup_iter: return args.legacy_large_lr * it / args.nb_warmup_iter - elif it < args.legacy_nb_epoch_large_lr: + elif n_epoch < args.legacy_nb_epoch_large_lr: return args.legacy_large_lr else: return args.legacy_small_lr @@ -725,7 +724,6 @@ model = mygpt.MyGPT( nb_heads=args.nb_heads, nb_lines=args.nb_lines, caterpillar_height=args.caterpillar_height, - dim_rec_v=args.dim_rec_v, nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout,