From: Francois Fleuret Date: Mon, 22 Jun 2020 07:59:39 +0000 (+0200) Subject: OCD update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f8719b07be90426c72d5bcb8ae52c9482642e365;p=pytorch.git OCD update. --- diff --git a/ddpol.py b/ddpol.py index 51e7636..35d98a0 100755 --- a/ddpol.py +++ b/ddpol.py @@ -83,12 +83,6 @@ def compute_mse(nb_train_samples): return mse_train.median(0).values, mse_test.median(0).values -###################################################################### - -torch.manual_seed(0) - -mse_train, mse_test = compute_mse(args.nb_train_samples) - ###################################################################### # Plot the MSE vs. degree curves @@ -100,7 +94,14 @@ ax.set_ylim(1e-5, 1) ax.set_xlabel('Polynomial degree', labelpad = 10) ax.set_ylabel('MSE', labelpad = 10) -ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5) +ax.axvline(x = args.nb_train_samples - 1, + color = 'gray', linewidth = 0.5, linestyle = '--') +ax.text(args.nb_train_samples - 1.2, 1e-4, 'Nb. params = nb. samples', + fontsize = 10, color = 'gray', + rotation = 90, rotation_mode='anchor') + +mse_train, mse_test = compute_mse(args.nb_train_samples) + ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error') ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error')