X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=69731ff89e5e290b2124d84eeffad5aafcabef03;hb=ac3d9ba45d72a7f3e399de4e3614698ac5e0ce39;hp=1d52b6defc3cd22ddea610e1fbfc071358354a47;hpb=cd3329fc206bacfd90a8e2cbe364244359568733;p=picoclvr.git diff --git a/main.py b/main.py index 1d52b6d..69731ff 100755 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, argparse, time, tqdm, os +import math, sys, argparse, time, tqdm, os, datetime import torch, torchvision from torch import nn @@ -257,9 +257,9 @@ default_task_args = { "nb_test_samples": 10000, }, "memory": { - "model": "4M", + "model": "37M", "batch_size": 100, - "nb_train_samples": 5000, + "nb_train_samples": 25000, "nb_test_samples": 1000, }, "mixing": { @@ -718,6 +718,8 @@ if nb_epochs_finished >= nb_epochs: deterministic_synthesis=args.deterministic_synthesis, ) +time_pred_result = None + for n_epoch in range(nb_epochs_finished, nb_epochs): learning_rate = learning_rate_schedule[n_epoch] @@ -776,6 +778,13 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): deterministic_synthesis=args.deterministic_synthesis, ) + time_current_result = datetime.datetime.now() + if time_pred_result is not None: + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) + time_pred_result = time_current_result + checkpoint = { "nb_epochs_finished": n_epoch + 1, "model_state": model.state_dict(),