+######################################################################
+# Plot multiple MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale("log")
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel("Polynomial degree", labelpad=10)
+ax.set_ylabel("MSE", labelpad=10)
+
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+ mse_train, mse_test = compute_mse(nb_train_samples)
+ e = float(nb_train_samples - nb_train_samples_min) / float(
+ nb_train_samples_max - nb_train_samples_min
+ )
+ e = 0.15 + 0.7 * e
+ ax.plot(
+ torch.arange(args.D_max + 1),
+ mse_train,
+ color=(e, e, 1.0),
+ label=f"Train N={nb_train_samples}",
+ )
+ ax.plot(
+ torch.arange(args.D_max + 1),
+ mse_test,
+ color=(1.0, e, e),
+ label=f"Test N={nb_train_samples}",
+ )
+
+ax.legend(frameon=False)
+
+fig.savefig("dd-multi-mse.pdf", bbox_inches="tight")
+
+plt.close(fig)