From: Francois Fleuret Date: Tue, 30 Jun 2020 09:06:04 +0000 (+0200) Subject: Added the generation of dd-multi-mse.pdf X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=437a0746551145f241b39d4a95ae28ecd1410a54;p=pytorch.git Added the generation of dd-multi-mse.pdf --- diff --git a/ddpol.py b/ddpol.py index 9d14d2a..6ace38d 100755 --- a/ddpol.py +++ b/ddpol.py @@ -111,6 +111,31 @@ fig.savefig('dd-mse.pdf', bbox_inches='tight') plt.close(fig) +###################################################################### +# Plot multiple MSE vs. degree curves + +fig = plt.figure() + +ax = fig.add_subplot(1, 1, 1) +ax.set_yscale('log') +ax.set_ylim(1e-5, 1) +ax.set_xlabel('Polynomial degree', labelpad = 10) +ax.set_ylabel('MSE', labelpad = 10) + +nb_train_samples_min = args.nb_train_samples - 4 +nb_train_samples_max = args.nb_train_samples + +for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2): + mse_train, mse_test = compute_mse(nb_train_samples) + e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min) + e = 0.15 + 0.7 * e + ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}') + ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}') + +fig.savefig('dd-multi-mse.pdf', bbox_inches='tight') + +plt.close(fig) + ###################################################################### # Plot some examples of train / test