From 21f814e17a7022e3332ea1f8ff6fc43c769e7e92 Mon Sep 17 00:00:00 2001
From: Francois Fleuret
Date: Sat, 6 Jun 2020 17:43:36 +0200
Subject: [PATCH] Update.

---
 ddpol.py | 88 +++++++++++++++++++++++++++++++-------------------------
 1 file changed, 49 insertions(+), 39 deletions(-)

diff --git a/ddpol.py b/ddpol.py
index 97d2ff5..db587fa 100755
--- a/ddpol.py
+++ b/ddpol.py
@@ -9,14 +9,23 @@ import math
 import matplotlib.pyplot as plt
 import torch
 
+nb_train_samples = 8
+D_max = 16
+nb_runs = 250
+train_noise_std = 0
+
 ######################################################################
 
-def compute_alpha(x, y, D, a = 0, b = 1, rho = 1e-11):
+def pol_value(alpha, x):
+    x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
+    return x_pow @ alpha
+
+def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
     B = y
 
-    if D+1 > 2:
-        q = torch.arange(2, D + 1).view( 1, -1).to(x.dtype)
+    if D >= 2:
+        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
         r = q.view(-1, 1)
         beta = x.new_zeros(D + 1, D + 1)
         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
@@ -25,50 +34,69 @@ def compute_alpha(x, y, D, a = 0, b = 1, rho = 1e-11):
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    alpha = torch.lstsq(B, M).solution.view(-1)[:D+1]
-
-    return alpha
+    return torch.lstsq(B, M).solution.view(-1)[:D+1]
 
 ######################################################################
 
 def phi(x):
-    return 4 * (x - 0.5) ** 2 * (x >= 0.5)
+    # return 4 * (x - 0.6) ** 2 * (x >= 0.6) - 4 * (x - 0.4) ** 2 * (x <= 0.4) + 0.5
+    # return 4 * (x - 0.5) ** 2 * (x >= 0.5)
+    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
+    # return x/2 - torch.sign(x-0.4) * 0.3
 
 ######################################################################
 
 torch.manual_seed(0)
 
-nb_train_samples = 7
-D_max = 16
-nb_runs = 250
-
 mse_train = torch.zeros(nb_runs, D_max + 1)
 mse_test = torch.zeros(nb_runs, D_max + 1)
 
 for k in range(nb_runs):
     x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+    # x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
     y_train = phi(x_train)
-    y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
+    if train_noise_std > 0:
+        y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
     x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
     y_test = phi(x_test)
 
     for D in range(D_max + 1):
-        alpha = compute_alpha(x_train, y_train, D)
-        X_train = x_train.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        mse_train[k, D] = ((X_train @ alpha - y_train)**2).mean()
-        mse_test[k, D] = ((X_test @ alpha - y_test)**2).mean()
+        alpha = fit_alpha(x_train, y_train, D)
+        mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
+        mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
 
 mse_train = mse_train.median(0).values
 mse_test = mse_test.median(0).values
 
 ######################################################################
+# Plot the MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel('Polynomial degree', labelpad = 10)
+ax.set_ylabel('MSE', labelpad = 10)
+
+ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
+ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
+ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
+
+ax.legend(frameon = False)
+
+fig.savefig('dd-mse.pdf', bbox_inches='tight')
+
+######################################################################
+# Plot some examples of train / test
 
-torch.manual_seed(4) # I picked that for pretty
+torch.manual_seed(5) # I picked that for pretty
 
 x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+# x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
 y_train = phi(x_train)
-y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
+if train_noise_std > 0:
+    y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
 x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
 y_test = phi(x_test)
 
@@ -81,29 +109,11 @@ for D in range(D_max + 1):
     ax.plot(x_test, y_test, color = 'blue', label = 'Test values')
     ax.scatter(x_train, y_train, color = 'blue', label = 'Training examples')
 
-    alpha = compute_alpha(x_train, y_train, D)
-    X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-    ax.plot(x_test, X_test @ alpha, color = 'red', label = 'Fitted polynomial')
+    alpha = fit_alpha(x_train, y_train, D)
+    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
 
     ax.legend(frameon = False)
 
     fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
 
 ######################################################################
-
-fig = plt.figure()
-
-ax = fig.add_subplot(1, 1, 1)
-ax.set_yscale('log')
-ax.set_xlabel('Polynomial degree', labelpad = 10)
-ax.set_ylabel('MSE', labelpad = 10)
-
-ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
-
-ax.legend(frameon = False)
-
-fig.savefig('dd-mse.pdf', bbox_inches='tight')
-
-######################################################################
-- 
2.39.5
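
Note, not part of the patch above: fit_alpha still relies on torch.lstsq, which
was deprecated in later PyTorch releases and eventually removed. A minimal
sketch of an equivalent call using torch.linalg.lstsq, assuming M and B are
built exactly as in fit_alpha:

    # torch.linalg.lstsq(A, B) minimizes ||A x - B||; the arguments are in the
    # opposite order of the old torch.lstsq(B, A), and its .solution already
    # has D + 1 rows, so the [:D+1] slice is no longer needed.
    alpha = torch.linalg.lstsq(M, B.view(-1, 1)).solution.view(-1)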