beta = x.new_zeros(D + 1, D + 1)
beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
l, U = beta.eig(eigenvectors = True)
- Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5)
+ Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) # clamp deals with ~0 negative values
B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
######################################################################
+# The "ground truth"
+
def phi(x):
return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
ax.axvline(x = args.nb_train_samples - 1,
color = 'gray', linewidth = 0.5, linestyle = '--')
+
ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples',
fontsize = 10, color = 'gray',
rotation = 90, rotation_mode='anchor')
mse_train, mse_test = compute_mse(args.nb_train_samples)
-
-ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
+ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train')
+ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test')
ax.legend(frameon = False)
plt.close(fig)
+######################################################################
+# Plot multiple MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel('Polynomial degree', labelpad = 10)
+ax.set_ylabel('MSE', labelpad = 10)
+
+nb_train_samples_min = args.nb_train_samples - 4
+nb_train_samples_max = args.nb_train_samples
+
+for nb_train_samples in range(nb_train_samples_min, nb_train_samples_max + 1, 2):
+ mse_train, mse_test = compute_mse(nb_train_samples)
+ e = float(nb_train_samples - nb_train_samples_min) / float(nb_train_samples_max - nb_train_samples_min)
+ e = 0.15 + 0.7 * e
+ ax.plot(torch.arange(args.D_max + 1), mse_train, color = (e, e, 1.0), label = f'Train N={nb_train_samples}')
+ ax.plot(torch.arange(args.D_max + 1), mse_test, color = (1.0, e, e), label = f'Test N={nb_train_samples}')
+
+ax.legend(frameon = False)
+
+fig.savefig('dd-multi-mse.pdf', bbox_inches='tight')
+
+plt.close(fig)
+
######################################################################
# Plot some examples of train / test