projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
main.py
diff --git
a/main.py
b/main.py
index
fd8ab41
..
9d95034
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-410,6
+410,10
@@
def create_c_quizzes(
nb_to_create = nb_for_train + nb_for_test
nb_to_create = nb_for_train + nb_for_test
+ warnings.warn(
+ f"{args.nb_gpts=} {args.nb_models_for_generation=} {args.min_to_validate=} {args.max_to_validate=}"
+ )
+
while nb_validated() < nb_to_create:
(
new_c_quizzes,
while nb_validated() < nb_to_create:
(
new_c_quizzes,
@@
-437,7
+441,8
@@
def create_c_quizzes(
for n in range(nb_correct.max() + 1):
recorded[n].append(new_c_quizzes[nb_correct == n].clone())
for n in range(nb_correct.max() + 1):
recorded[n].append(new_c_quizzes[nb_correct == n].clone())
- nv = [recorded[n][-1].size(0) for n in recorded.keys()]
+ nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+ nv = " ".join([str(x.item()) for x in nv])
log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")
log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}")