Update.
[pytorch] / ddpol.py
index 97d2ff5..db587fa 100755 (executable)
--- a/ddpol.py
+++ b/ddpol.py
@@ -9,14 +9,23 @@ import math
 import matplotlib.pyplot as plt
 import torch
 
+nb_train_samples = 8
+D_max = 16
+nb_runs = 250
+train_noise_std = 0
+
 ######################################################################
 
-def compute_alpha(x, y, D, a = 0, b = 1, rho = 1e-11):
+def pol_value(alpha, x):
+    x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
+    return x_pow @ alpha
+
+def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
     M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
     B = y
 
-    if D+1 > 2:
-        q = torch.arange(2, D + 1).view( 1, -1).to(x.dtype)
+    if D >= 2:
+        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
         r = q.view(-1,  1)
         beta = x.new_zeros(D + 1, D + 1)
         beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
@@ -25,50 +34,69 @@ def compute_alpha(x, y, D, a = 0, b = 1, rho = 1e-11):
         B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
         M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
 
-    alpha = torch.lstsq(B, M).solution.view(-1)[:D+1]
-
-    return alpha
+    return torch.lstsq(B, M).solution.view(-1)[:D+1]
 
 ######################################################################
 
 def phi(x):
-    return 4 * (x - 0.5) ** 2 * (x >= 0.5)
+    # return 4 * (x - 0.6) ** 2 * (x >= 0.6) - 4 * (x - 0.4) ** 2 * (x <= 0.4) + 0.5
+    # return 4 * (x - 0.5) ** 2 * (x >= 0.5)
+    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
+    # return x/2 - torch.sign(x-0.4) * 0.3
 
 ######################################################################
 
 torch.manual_seed(0)
 
-nb_train_samples = 7
-D_max = 16
-nb_runs = 250
-
 mse_train = torch.zeros(nb_runs, D_max + 1)
 mse_test = torch.zeros(nb_runs, D_max + 1)
 
 for k in range(nb_runs):
     x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+    # x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
     y_train = phi(x_train)
-    y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
+    if train_noise_std > 0:
+        y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
     x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
     y_test = phi(x_test)
 
     for D in range(D_max + 1):
-        alpha = compute_alpha(x_train, y_train, D)
-        X_train = x_train.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-        mse_train[k, D] = ((X_train @ alpha - y_train)**2).mean()
-        mse_test[k, D] = ((X_test @ alpha - y_test)**2).mean()
+        alpha = fit_alpha(x_train, y_train, D)
+        mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
+        mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
 
 mse_train = mse_train.median(0).values
 mse_test = mse_test.median(0).values
 
 ######################################################################
+# Plot the MSE vs. degree curves
+
+fig = plt.figure()
+
+ax = fig.add_subplot(1, 1, 1)
+ax.set_yscale('log')
+ax.set_ylim(1e-5, 1)
+ax.set_xlabel('Polynomial degree', labelpad = 10)
+ax.set_ylabel('MSE', labelpad = 10)
+
+ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
+ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
+ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
+
+ax.legend(frameon = False)
+
+fig.savefig('dd-mse.pdf', bbox_inches='tight')
+
+######################################################################
+# Plot some examples of train / test
 
-torch.manual_seed(4) # I picked that for pretty
+torch.manual_seed(5) # I picked that for pretty
 
 x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+# x_train = torch.linspace(0, 1, nb_train_samples, dtype = torch.float64)
 y_train = phi(x_train)
-y_train = y_train + torch.empty(y_train.size(), dtype = y_train.dtype).normal_(0, 0.1)
+if train_noise_std > 0:
+    y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
 x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
 y_test = phi(x_test)
 
@@ -81,29 +109,11 @@ for D in range(D_max + 1):
     ax.plot(x_test, y_test, color = 'blue', label = 'Test values')
     ax.scatter(x_train, y_train, color = 'blue', label = 'Training examples')
 
-    alpha = compute_alpha(x_train, y_train, D)
-    X_test = x_test.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
-    ax.plot(x_test, X_test @ alpha, color = 'red', label = 'Fitted polynomial')
+    alpha = fit_alpha(x_train, y_train, D)
+    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')
 
     ax.legend(frameon = False)
 
     fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
 
 ######################################################################
-
-fig = plt.figure()
-
-ax = fig.add_subplot(1, 1, 1)
-ax.set_yscale('log')
-ax.set_xlabel('Polynomial degree', labelpad = 10)
-ax.set_ylabel('MSE', labelpad = 10)
-
-ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
-
-ax.legend(frameon = False)
-
-fig.savefig('dd-mse.pdf', bbox_inches='tight')
-
-######################################################################