3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
import argparse
import math

import matplotlib.pyplot as plt
import torch
13 ######################################################################
# Command-line interface: every experiment parameter can be overridden
# from the shell, the defaults reproduce the reference figures.

parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.')

# Largest polynomial degree that is fitted (degrees 0..D_max are tried)
parser.add_argument('--D-max',
                    type = int, default = 16)

# Number of independent train sets over which the MSE median is taken
parser.add_argument('--nb-runs',
                    type = int, default = 250)

# Number of train points drawn for each run
parser.add_argument('--nb-train-samples',
                    type = int, default = 8)

# Standard deviation of the Gaussian noise added to the train targets
# (no noise is added when it is not strictly positive)
parser.add_argument('--train-noise-std',
                    type = float, default = 0.)

parser.add_argument('--seed',
                    type = int, default = 0,
                    help = 'Random seed (default 0, < 0 is no seeding)')

args = parser.parse_args()

# Honor the documented contract of --seed: a negative value leaves the
# random generator unseeded instead of being passed to manual_seed.
if args.seed >= 0:
    torch.manual_seed(args.seed)
38 ######################################################################
def pol_value(alpha, x):
    """Evaluate a polynomial at the given points.

    Args:
        alpha: 1-D tensor of coefficients, lowest degree first, so that
            the polynomial is sum_d alpha[d] * x**d.
        x: tensor of evaluation points (flattened to one dimension).

    Returns:
        A 1-D tensor with one polynomial value per element of x.
    """
    # One row per point, one column per power 0 .. alpha.size(0) - 1
    x_pow = x.view(-1, 1) ** torch.arange(alpha.size(0)).view(1, -1)
    return x_pow @ alpha
def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12):
    """Fit a polynomial of degree D to the points (x, y) by least squares.

    For D >= 2 the quadratic loss is augmented with the smoothness
    penalty rho * integral_a^b f''(t)^2 dt, where f is the fitted
    polynomial. The penalty is expressed as extra rows of the linear
    system so that a single least-squares solve handles both terms.

    Args:
        x: 1-D tensor of sample abscissas.
        y: 1-D tensor of sample ordinates, same length as x.
        D: degree of the polynomial to fit.
        a: lower bound of the interval on which the penalty is integrated.
        b: upper bound of that interval.
        rho: weight of the smoothness penalty.

    Returns:
        A 1-D tensor of D + 1 coefficients, lowest degree first.
    """
    # Vandermonde matrix: one row per sample, one column per power 0..D
    M = x.view(-1, 1) ** torch.arange(D + 1).view(1, -1)
    B = y

    # Polynomials of degree < 2 have a null second derivative, so the
    # penalty only exists from degree 2 on.
    if D >= 2:
        q = torch.arange(2, D + 1, dtype = x.dtype).view(1, -1)
        r = q.view(-1, 1)
        # beta[q, r] = integral_a^b (t^q)'' (t^r)'' dt, hence
        # alpha^T beta alpha = integral_a^b f''(t)^2 dt
        beta = x.new_zeros(D + 1, D + 1)
        beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3)
        # beta is symmetric, so eigh gives a real decomposition
        # beta = U diag(l) U^T; clamping removes the slightly negative
        # eigenvalues produced by round-off before taking the root.
        l, U = torch.linalg.eigh(beta)
        Q = U @ torch.diag(l.clamp(min = 0) ** 0.5)
        # Appending sqrt(rho) * Q^T rows with null targets adds
        # rho * alpha^T beta alpha to the squared residual.
        B = torch.cat((B, y.new_zeros(Q.size(0))), 0)
        M = torch.cat((M, math.sqrt(rho) * Q.t()), 0)

    # torch.linalg.lstsq takes (A, B), the reverse of the removed
    # torch.lstsq(B, A), and wants a 2-D right-hand side. The 'gels'
    # driver is the one torch.lstsq relied on.
    return torch.linalg.lstsq(M, B.reshape(-1, 1), driver = 'gels').solution[:D+1, 0]
60 ######################################################################
def phi(x):
    """Ground-truth function the polynomials are fitted to.

    Piecewise-linear in x, with kinks at 0.2, 0.4 and 0.6.

    Args:
        x: tensor of abscissas.

    Returns:
        A tensor of the same shape as x holding the target values.
    """
    return torch.abs(torch.abs(x - 0.4) - 0.2) + x/2 - 0.1
65 ######################################################################
def compute_mse(nb_train_samples):
    """Estimate train and test MSE for every polynomial degree.

    For each of args.nb_runs runs, a fresh train set is drawn uniformly
    in [0, 1), optionally corrupted with Gaussian noise of standard
    deviation args.train_noise_std, and a polynomial of every degree
    from 0 to args.D_max is fitted to it.

    Args:
        nb_train_samples: number of train points drawn at each run.

    Returns:
        A pair (mse_train, mse_test) of 1-D tensors of length
        args.D_max + 1 holding, for each degree, the median over the
        runs of the train and test mean squared errors.
    """
    mse_train = torch.zeros(args.nb_runs, args.D_max + 1)
    mse_test = torch.zeros(args.nb_runs, args.D_max + 1)

    # The test set is a fixed regular grid: it is the same for every
    # run and consumes no random numbers, so it is built only once.
    x_test = torch.linspace(0, 1, 100, dtype = torch.float64)
    y_test = phi(x_test)

    for k in range(args.nb_runs):
        x_train = torch.rand(nb_train_samples, dtype = torch.float64)
        y_train = phi(x_train)
        if args.train_noise_std > 0:
            y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)

        for D in range(args.D_max + 1):
            alpha = fit_alpha(x_train, y_train, D)
            mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean()
            mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean()

    # The median over runs is taken separately for each degree
    return mse_train.median(0).values, mse_test.median(0).values
86 ######################################################################
mse_train, mse_test = compute_mse(args.nb_train_samples)

######################################################################
# Plot the MSE vs. degree curves

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

ax.set_xlabel('Polynomial degree', labelpad = 10)
ax.set_ylabel('MSE', labelpad = 10)

# Vertical line at degree nb_train_samples - 1, where the polynomial
# has as many coefficients as there are train points
ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5)
ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error')
ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')

ax.legend(frameon = False)

fig.savefig('dd-mse.pdf', bbox_inches='tight')

# Release the figure once it is written to disk
plt.close(fig)
113 ######################################################################
# Plot some examples of train / test

torch.manual_seed(9) # I picked that for pretty

# A single train set is shared by all the example plots
x_train = torch.rand(args.nb_train_samples, dtype = torch.float64)
y_train = phi(x_train)
if args.train_noise_std > 0:
    y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std)
x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype)
y_test = phi(x_test)

# One PDF per degree, showing the target, the train points and the fit
for D in range(args.D_max + 1):
    fig = plt.figure()

    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(f'Degree {D}')
    ax.set_ylim(-0.1, 1.1)
    ax.plot(x_test, y_test, color = 'black', label = 'Test values')
    ax.scatter(x_train, y_train, color = 'blue', label = 'Train samples')

    alpha = fit_alpha(x_train, y_train, D)
    ax.plot(x_test, pol_value(alpha, x_test), color = 'red', label = 'Fitted polynomial')

    ax.legend(frameon = False)

    fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight')

    # Close each figure after saving so they do not accumulate in
    # memory across the D_max + 1 iterations
    plt.close(fig)
143 ######################################################################