Update: split the train/test code into one_epoch() and run_tests().
author     François Fleuret <francois@fleuret.org>
           Fri, 21 Jun 2024 06:51:00 +0000 (08:51 +0200)
committer  François Fleuret <francois@fleuret.org>
           Fri, 21 Jun 2024 06:51:00 +0000 (08:51 +0200)
main.py

diff --git a/main.py b/main.py
index 5234d6f..d92c4a5 100755
--- a/main.py
+++ b/main.py
@@ -815,9 +815,10 @@ if nb_epochs_finished >= args.nb_epochs:
 
 time_pred_result = None
 
-for n_epoch in range(nb_epochs_finished, args.nb_epochs):
-    learning_rate = learning_rate_schedule[n_epoch]
+######################################################################
+
 
+def one_epoch(model, task, learning_rate):
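+    # Run one full training pass over the task's training set with the
+    # given learning rate, and log the resulting train perplexity.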
     log_string(f"learning_rate {learning_rate}")
 
     if args.optim == "sgd":
@@ -850,6 +851,15 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         if nb_train_samples % args.batch_size == 0:
             optimizer.step()
 
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+
+
+######################################################################
+
+
+def run_tests(model, task):
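+    # Evaluate the model on the test set without gradient tracking and
+    # log the test perplexity; n_epoch is read from the global scope.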
     with torch.autograd.no_grad():
         model.eval()
 
@@ -868,13 +878,6 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
             nb_test_samples += input.size(0)
 
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(
-            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
-        )
-
         task.produce_results(
             n_epoch=n_epoch,
             model=model,
@@ -883,12 +886,25 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
             deterministic_synthesis=args.deterministic_synthesis,
         )
 
-        time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
-        time_pred_result = time_current_result
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+
+######################################################################
+
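+# Main loop: train for one epoch, evaluate, estimate when the next
+# result will be ready, and checkpoint.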
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+    learning_rate = learning_rate_schedule[n_epoch]
+
+    one_epoch(model, task, learning_rate)
+
+    run_tests(model, task)
+
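+    # Predict when the next result will be available by adding this
+    # epoch's duration to the current time.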
+    time_current_result = datetime.datetime.now()
+    if time_pred_result is not None:
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
+    time_pred_result = time_current_result
 
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,