X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=969b47f7ac144870b0598514e6593a64e52daee8;hb=cebc20b3608a41bfd27b2ab9d950c082f9b7ea89;hp=cae20f89347b076226b6525d561dab1d89a55e8b;hpb=6183291906184569c2206c34588d118cc77f74bb;p=mygptrnn.git diff --git a/main.py b/main.py index cae20f8..969b47f 100755 --- a/main.py +++ b/main.py @@ -16,14 +16,6 @@ import mygpt, tasks, problems ###################################################################### -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") - -###################################################################### - def str2bool(x): x = x.lower() @@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0) parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--force_cpu", type=str2bool, default=False) + ######################################## parser.add_argument("--nb_epochs", type=int, default=50) @@ -107,8 +101,6 @@ parser.add_argument("--caterpillar_height", type=int, default=None) parser.add_argument("--rho", type=float, default=0.0) -parser.add_argument("--dim_rec_v", type=int, default=None) - parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) @@ -119,7 +111,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--no_checkpoint", action="store_true", default=False) -parser.add_argument("--overwrite_results", action="store_true", default=False) +parser.add_argument("--continue_training", action="store_true", default=False) parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") @@ -219,6 +211,14 @@ if args.result_dir is None: ###################################################################### +if not args.force_cpu and torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") + +###################################################################### + default_task_args = { "addition": { "model": "352M", @@ -332,7 +332,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 32, "nb_heads": 2, - "dim_rec_v": 16, "nb_blocks": 2, }, "17K-C": { @@ -343,7 +342,6 @@ default_model_args = { "nb_heads": 2, "nb_lines": 16, "caterpillar_height": 4, - "dim_rec_v": 16, "nb_blocks": 2, }, "4M": { @@ -352,7 +350,6 @@ default_model_args = { "dim_keys": 32, "dim_hidden": 1024, "nb_heads": 4, - "dim_rec_v": 64, "nb_blocks": 6, }, "4M-C": { @@ -363,7 +360,6 @@ default_model_args = { "nb_heads": 4, "nb_lines": 32, "caterpillar_height": 4, - "dim_rec_v": 64, # dim_model / nb_heads "nb_blocks": 6, }, "37M": { @@ -372,7 +368,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 64, "nb_blocks": 12, }, "37M-C": { @@ -383,7 +378,6 @@ default_model_args = { "nb_heads": 8, "nb_lines": 256, "caterpillar_height": 32, - "dim_rec_v": 64, "nb_blocks": 12, }, "122M": { @@ -392,7 +386,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 96, "nb_blocks": 24, }, "122M-C": { @@ -402,7 +395,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 96, "nb_blocks": 24, }, "352M": { @@ -411,7 +403,6 @@ default_model_args = { "dim_keys": 64, "dim_hidden": 2048, "nb_heads": 8, - "dim_rec_v": 128, "nb_blocks": 48, }, "352M-C": { @@ -421,7 +412,6 @@ default_model_args = { "dim_hidden": 2048, "nb_heads": 8, "nb_lines": 128, - "dim_rec_v": 128, "nb_blocks": 48, }, } @@ -438,7 +428,7 @@ else: try: os.mkdir(args.result_dir) except FileExistsError: - if not args.overwrite_results: + if not args.continue_training: print(f"result directory {args.result_dir} already exists") exit(1) @@ -736,7 +726,6 @@ model = mygpt.MyGPT( nb_heads=args.nb_heads, nb_lines=args.nb_lines, caterpillar_height=args.caterpillar_height, - dim_rec_v=args.dim_rec_v, nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, @@ -845,7 +834,7 @@ if nb_epochs_finished >= nb_epochs: deterministic_synthesis=args.deterministic_synthesis, ) -time_pred_result = None +time_pred_result = datetime.datetime.now() it = 0 @@ -923,10 +912,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): ) time_current_result = datetime.datetime.now() - if time_pred_result is not None: - log_string( - f"next_result {time_current_result + (time_current_result - time_pred_result)}" - ) + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) time_pred_result = time_current_result checkpoint = {