X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=13fbe8e34adb0d3f06dd75af2a16bcf805b5d949;hb=91e75aba9a2a250985843bbf0ccb81f39dd97ce4;hp=970ee7b528f0f3f91bd777b760b48140c343f678;hpb=68c17359790a9b8ac931a3679f08ad6a82a4e640;p=mygpt.git diff --git a/mygpt.py b/mygpt.py index 970ee7b..13fbe8e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -252,6 +252,7 @@ class TaskPicoCLVR(Task): def produce_results(self, n_epoch, model, nb_tokens = 50): img = [ ] nb_per_primer = 8 + for primer in [ 'red above green green top blue right of red ', 'there is red there is yellow there is blue ', @@ -507,7 +508,10 @@ for k in range(args.nb_epochs): acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) - log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}') + train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) + + log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') task.produce_results(k, model)