ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')