X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=9d950344d1579d1e1678b075ef48053f39689962;hb=6b4e192557e03528ffd10364123de454aa9c9f08;hp=fd8ab4191e41b27199ce793c411e5cf346385c1b;hpb=51540cefc448684d5086297d23e9a1805da4d405;p=culture.git diff --git a/main.py b/main.py index fd8ab41..9d95034 100755 --- a/main.py +++ b/main.py @@ -410,6 +410,10 @@ def create_c_quizzes( nb_to_create = nb_for_train + nb_for_test + warnings.warn( + f"{args.nb_gpts=} {args.nb_models_for_generation=} {args.min_to_validate=} {args.max_to_validate=}" + ) + while nb_validated() < nb_to_create: ( new_c_quizzes, @@ -437,7 +441,8 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nv = [recorded[n][-1].size(0) for n in recorded.keys()] + nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) + nv = " ".join([str(x.item()) for x in nv]) log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")