X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=90b4c6d6693f51a1ef35b54ed2eb8130865277c7;hb=4d3bc68c677cc9554df9c47dd214dfc4cb9c6577;hp=59133455353fc3c27c97466fc5595c0b8b9abbb0;hpb=cefdf80cffc5f897dc728d68bf927f522e3e1608;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 5913345..90b4c6d 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -88,13 +88,16 @@ print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL) def log_string(s): global pred_log_t + t = time.time() if pred_log_t is None: elapsed = 'start' else: elapsed = '+{:.02f}s'.format(t - pred_log_t) + pred_log_t = t + s = Fore.BLUE + time.ctime() + ' ' + Fore.GREEN + elapsed + Style.RESET_ALL + ' ' + s log_file.write(s + '\n') log_file.flush() @@ -146,6 +149,8 @@ def train_model(model, train_set): optimizer = optim.SGD(model.parameters(), lr = 1e-2) + start_t = time.time() + for e in range(0, args.nb_epochs): acc_loss = 0.0 for b in range(0, train_set.nb_batches): @@ -157,6 +162,8 @@ def train_model(model, train_set): loss.backward() optimizer.step() log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss)) + dt = (time.time() - start_t) / (e + 1) + print(Fore.CYAN + 'ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + Style.RESET_ALL) return model