projects
/
culture.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
df414ba
)
Update.
author
François Fleuret
<francois@fleuret.org>
Sat, 13 Jul 2024 05:21:40 +0000
(07:21 +0200)
committer
François Fleuret
<francois@fleuret.org>
Sat, 13 Jul 2024 05:21:40 +0000
(07:21 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
a8ceac8
..
9599cf3
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-78,10
+78,6
@@
parser.add_argument("--gpus", type=str, default="all")
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--min_to_validate", type=int, default=None)
-
-parser.add_argument("--max_to_validate", type=int, default=None)
-
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
parser.add_argument("--proba_understands", type=float, default=0.99)
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
parser.add_argument("--proba_understands", type=float, default=0.99)
@@
-121,12
+117,6
@@
parser.add_argument("--sky_speed", type=int, default=3)
args = parser.parse_args()
args = parser.parse_args()
-if args.min_to_validate is None:
- args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
- args.max_to_validate = args.nb_gpts - 1
-
if args.result_dir is None:
args.result_dir = f"results_culture"
if args.result_dir is None:
args.result_dir = f"results_culture"
@@
-338,10
+328,10
@@
def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
def one_epoch(model, quiz_machine, local_device=main_device):
def one_epoch(model, quiz_machine, local_device=main_device):
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
model.to(local_device).train()
model.to(local_device).train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
nb_train_samples, acc_train_loss = 0, 0.0
for input in quiz_machine.batches(model, split="train"):
nb_train_samples, acc_train_loss = 0, 0.0
for input in quiz_machine.batches(model, split="train"):