Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 1d52b6d..9f82594 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, argparse, time, tqdm, os
+import math, sys, argparse, time, tqdm, os, datetime
 
 import torch, torchvision
 from torch import nn
@@ -104,6 +104,8 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_fraction_play", type=float, default=0)
+
 ##############################
 # picoclvr options
 
@@ -257,9 +259,9 @@ default_task_args = {
         "nb_test_samples": 10000,
     },
     "memory": {
-        "model": "4M",
+        "model": "37M",
         "batch_size": 100,
-        "nb_train_samples": 5000,
+        "nb_train_samples": 25000,
         "nb_test_samples": 1000,
     },
     "mixing": {
@@ -554,6 +556,7 @@ elif args.task == "grid":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         size=args.grid_size,
+        fraction_play=args.grid_fraction_play,
         logger=log_string,
         device=device,
     )
@@ -718,6 +721,8 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
+time_pred_result = None
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
@@ -776,6 +781,13 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             deterministic_synthesis=args.deterministic_synthesis,
         )
 
+        time_current_result = datetime.datetime.now()
+        if time_pred_result is not None:
+            log_string(
+                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+            )
+        time_pred_result = time_current_result
+
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,
         "model_state": model.state_dict(),