X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=04e56527a51935496c69dbbb6124bd18bc160f28;hb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;hp=3e67a73fab1452f0d1f97aff1f890fe3ab70bcab;hpb=3dd98b99909b2bca323673263874e2abb39ac10c;p=mygptrnn.git diff --git a/main.py b/main.py index 3e67a73..04e5652 100755 --- a/main.py +++ b/main.py @@ -835,6 +835,24 @@ if args.max_percents_of_test_in_train >= 0: ############################## +for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + +for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + +exit(0) + +############################## + nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: