X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=mygpt.py;h=13fbe8e34adb0d3f06dd75af2a16bcf805b5d949;hb=91e75aba9a2a250985843bbf0ccb81f39dd97ce4;hp=e6387bdc33a7a6c68bd702f44949dd7c1fec32e5;hpb=0368fffb6da510a7fd8a3070dd3df53476913630;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index e6387bd..13fbe8e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -508,7 +508,10 @@ for k in range(args.nb_epochs): acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) - log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}') + train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) + + log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') task.produce_results(k, model)