ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
+ax.legend(frameon = False)
+
fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
plt.close(fig)