X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=67c57c0eaa3e316563e524e75380460e67348968;hb=60bf08d4197f2dd3a58bd900401c11d47225b0df;hp=fd8ab4191e41b27199ce793c411e5cf346385c1b;hpb=51540cefc448684d5086297d23e9a1805da4d405;p=culture.git diff --git a/main.py b/main.py index fd8ab41..67c57c0 100755 --- a/main.py +++ b/main.py @@ -437,7 +437,8 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nv = [recorded[n][-1].size(0) for n in recorded.keys()] + nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) + nv = " ".join([str(x.item()) for x in nv]) log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")