From 329837be2f41d7839046cc5ab0825b824825bf84 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 6 Jun 2020 17:58:09 +0200 Subject: [PATCH] Update. --- ddpol.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ddpol.py b/ddpol.py index db587fa..9a1bbc9 100755 --- a/ddpol.py +++ b/ddpol.py @@ -39,10 +39,7 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): ###################################################################### def phi(x): - # return 4 * (x - 0.6) ** 2 * (x >= 0.6) - 4 * (x - 0.4) ** 2 * (x <= 0.4) + 0.5 - # return 4 * (x - 0.5) ** 2 * (x >= 0.5) return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1 - # return x/2 - torch.sign(x-0.4) * 0.3 ###################################################################### @@ -53,7 +50,6 @@ mse_test = torch.zeros(nb_runs, D_max + 1) for k in range(nb_runs): x_train = torch.rand(nb_train_samples, dtype = torch.float64) - # x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64) y_train = phi(x_train) if train_noise_std > 0: y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std) @@ -90,10 +86,9 @@ fig.savefig('dd-mse.pdf', bbox_inches='tight') ###################################################################### # Plot some examples of train / test -torch.manual_seed(5) # I picked that for pretty +torch.manual_seed(9) # I picked that for pretty x_train = torch.rand(nb_train_samples, dtype = torch.float64) -# x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64) y_train = phi(x_train) if train_noise_std > 0: y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std) @@ -106,8 +101,8 @@ for D in range(D_max + 1): ax = fig.add_subplot(1, 1, 1) ax.set_title(f'Degree {D}') ax.set_ylim(-0.1, 1.1) - ax.plot(x_test, y_test, color = 'blue', label = 'Test values') - ax.scatter(x_train, y_train, color = 'blue', label = 'Training examples') + ax.plot(x_test, y_test, color = 'black', label = 'Test values') + ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples') alpha = fit_alpha(x_train, y_train, D) ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial') -- 2.39.5