# Written by Francois Fleuret <francois@fleuret.org>
-import math
+import math, argparse
import matplotlib.pyplot as plt
+
import torch
-nb_train_samples = 8
-D_max = 16
-nb_runs = 250
-train_noise_std = 0
+######################################################################
+
+# Command-line options. Each default is the value that used to be
+# hard-coded in the script, so running without arguments is unchanged.
+parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')
+
+parser.add_argument('--D-max',
+                    type = int, default = 16,
+                    help = 'Maximum polynomial degree; degrees 0 to D-max are fitted (default 16)')
+
+parser.add_argument('--nb-runs',
+                    type = int, default = 250,
+                    help = 'Number of random training sets over which the median MSE is taken (default 250)')
+
+parser.add_argument('--nb-train-samples',
+                    type = int, default = 8,
+                    help = 'Number of training points drawn per run (default 8)')
+
+parser.add_argument('--train-noise-std',
+                    type = float, default = 0.,
+                    help = 'Standard deviation of the Gaussian noise added to the training targets, 0 is no noise (default 0)')
+
+parser.add_argument('--seed',
+                    type = int, default = 0,
+                    help = 'Random seed (default 0, < 0 is no seeding)')
+
+args = parser.parse_args()
+
+# Seed the generator once, up front; a negative seed leaves it unseeded.
+if args.seed >= 0:
+    torch.manual_seed(args.seed)
######################################################################
beta = x.new_zeros(D + 1, D + 1)
beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
l, U = beta.eig(eigenvectors = True)
- Q = U @ torch.diag(l[:, 0].pow(0.5))
+ Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5)
B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)
######################################################################
-torch.manual_seed(0)
+def compute_mse(nb_train_samples):
+ # Runs args.nb_runs independent fits with nb_train_samples training
+ # points each and returns the pair (train, test) of per-degree
+ # median MSEs, one entry per degree from 0 to args.D_max.
+ # NOTE(review): reads nb_runs, D_max and train_noise_std from the
+ # global args, and relies on phi, fit_alpha and pol_value being
+ # defined elsewhere in the file.
+ mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
+ mse_test = torch.zeros(args.nb_runs, args.D_max + 1)
+
+ for k in range(args.nb_runs):
+ # Fresh random training inputs for each run.
+ x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+ y_train = phi(x_train)
+ # Optionally perturb the training targets with zero-mean Gaussian noise.
+ if args.train_noise_std > 0:
+ y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
+ # The test set is a fixed grid of 100 points on [0, 1], without noise.
+ x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
+ y_test = phi(x_test)
+
+ # Fit one polynomial per degree and record both errors for this run.
+ for D in range(args.D_max + 1):
+ alpha = fit_alpha(x_train, y_train, D)
+ mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
+ mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
-mse_train = torch.zeros(nb_runs, D_max + 1)
-mse_test = torch.zeros(nb_runs, D_max + 1)
+ # Median over the runs (dimension 0), separately for each degree.
+ return mse_train.median(0).values, mse_test.median(0).values
-for k in range(nb_runs):
- x_train = torch.rand(nb_train_samples, dtype = torch.float64)
- y_train = phi(x_train)
- if train_noise_std > 0:
- y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
- x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
- y_test = phi(x_test)
+######################################################################
- for D in range(D_max + 1):
- alpha = fit_alpha(x_train, y_train, D)
- mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
- mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()
+# No re-seeding here: the generator is seeded from --seed right after
+# argument parsing, and forcing a constant seed at this point would
+# silently override that option (including "< 0 is no seeding").
-mse_train = mse_train.median(0).values
-mse_test = mse_test.median(0).values
+mse_train, mse_test = compute_mse(args.nb_train_samples)
######################################################################
# Plot the MSE vs. degree curves
ax.set_xlabel('Polynomial degree', labelpad = 10)
ax.set_ylabel('MSE', labelpad = 10)
-ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5)
-ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error')
-ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error')
+# Vertical marker at degree nb_train_samples - 1, the degree at which
+# the polynomial has as many coefficients as there are training points.
+ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
+# Median train / test MSE as a function of the polynomial degree.
+ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
+ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')
ax.legend(frameon = False)
fig.savefig('dd-mse.pdf', bbox_inches='tight')
+# Release the figure now that it has been written to disk.
+plt.close(fig)
+
+
######################################################################
# Plot some examples of train / test
torch.manual_seed(9) # I picked that for pretty
-x_train = torch.rand(nb_train_samples, dtype = torch.float64)
+x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
y_train = phi(x_train)
-if train_noise_std > 0:
- y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std)
+# Same optional Gaussian perturbation of the training targets as in compute_mse.
+if args.train_noise_std > 0:
+ y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
y_test = phi(x_test)
-for D in range(D_max + 1):
+# One example figure per polynomial degree, from 0 to D_max.
+for D in range(args.D_max + 1):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')
+ # A new figure is created at every iteration; close it once saved
+ # so that open figures do not accumulate across degrees.
+ plt.close(fig)
+
+
######################################################################