X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=9f825941bbba6d6cf00a4ea72cbbaef008cab7be;hb=128d372813e99d8474bb6e967d5c7e7f085c819d;hp=1d52b6defc3cd22ddea610e1fbfc071358354a47;hpb=cd3329fc206bacfd90a8e2cbe364244359568733;p=picoclvr.git diff --git a/main.py b/main.py index 1d52b6d..9f82594 100755 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, argparse, time, tqdm, os +import math, sys, argparse, time, tqdm, os, datetime import torch, torchvision from torch import nn @@ -104,6 +104,8 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_fraction_play", type=float, default=0) + ############################## # picoclvr options @@ -257,9 +259,9 @@ default_task_args = { "nb_test_samples": 10000, }, "memory": { - "model": "4M", + "model": "37M", "batch_size": 100, - "nb_train_samples": 5000, + "nb_train_samples": 25000, "nb_test_samples": 1000, }, "mixing": { @@ -554,6 +556,7 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + fraction_play=args.grid_fraction_play, logger=log_string, device=device, ) @@ -718,6 +721,8 @@ if nb_epochs_finished >= nb_epochs: deterministic_synthesis=args.deterministic_synthesis, ) +time_pred_result = None + for n_epoch in range(nb_epochs_finished, nb_epochs): learning_rate = learning_rate_schedule[n_epoch] @@ -776,6 +781,13 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): deterministic_synthesis=args.deterministic_synthesis, ) + time_current_result = datetime.datetime.now() + if time_pred_result is not None: + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) + time_pred_result = time_current_result + checkpoint = { "nb_epochs_finished": n_epoch + 1, "model_state": model.state_dict(),