X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=04e56527a51935496c69dbbb6124bd18bc160f28;hb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;hp=969b47f7ac144870b0598514e6593a64e52daee8;hpb=cebc20b3608a41bfd27b2ab9d950c082f9b7ea89;p=mygptrnn.git diff --git a/main.py b/main.py index 969b47f..04e5652 100755 --- a/main.py +++ b/main.py @@ -202,9 +202,11 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default ###################################################################### -args = parser.parse_args() +# args = parser.parse_args() -assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +args, sup_args = parser.parse_known_args() + +sup_args = dict([x.removeprefix("--").split("=") for x in sup_args]) if args.result_dir is None: args.result_dir = f"results_{args.task}_{args.model}" @@ -432,6 +434,8 @@ except FileExistsError: print(f"result directory {args.result_dir} already exists") exit(1) +loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a") + log_file = open(os.path.join(args.result_dir, args.log_filename), "a") if args.seed >= 0: @@ -468,6 +472,9 @@ log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): log_string(f"args.{n} {getattr(args, n)}") +for k, v in sup_args.items(): + log_string(f'sup_args["{k}"] "{v}"') + ###################################################################### @@ -505,6 +512,9 @@ def get_lr(n_epoch, it): ###################################################################### +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} + + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -730,6 +740,8 @@ model = mygpt.MyGPT( causal=True, dropout=args.dropout, attention_layer=args.attention, + logger=log_string, + **sup_args, ) model.to(device) @@ -823,6 +835,24 @@ if args.max_percents_of_test_in_train >= 0: ############################## +for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + +for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + +exit(0) + +############################## + nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: @@ -838,6 +868,8 @@ time_pred_result = datetime.datetime.now() it = 0 +n_batch = 0 + for n_epoch in range(nb_epochs_finished, nb_epochs): if args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) @@ -879,6 +911,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): total_loss.backward() optimizer.step() + grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() + + loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + + n_batch += 1 + with torch.autograd.no_grad(): model.eval()