-ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
+ax.axvline(x = args.nb_train_samples - 1,
+ color = 'gray', linewidth = 0.5, linestyle = '--')
+
+ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
+ fontsize = 10, color = 'gray',
+ rotation = 90, rotation_mode='anchor')
+
+mse_train, mse_test = compute_mse(args.nb_train_samples)
+ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train')
+ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test')