From 1660da6ae12f49c59876ff4be2f1cded4c8e4d1e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 26 Mar 2024 18:58:56 +0100 Subject: [PATCH] Update. --- bit_mlp.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/bit_mlp.py b/bit_mlp.py index 8fffe7a..90409f2 100755 --- a/bit_mlp.py +++ b/bit_mlp.py @@ -9,7 +9,7 @@ import os, sys import torch, torchvision from torch import nn -lr, nb_epochs, batch_size = 2e-3, 50, 100 +lr, nb_epochs, batch_size = 2e-3, 100, 100 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/" @@ -57,8 +57,10 @@ class QLinear(nn.Module): ###################################################################### -for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: - for linear_layer in [nn.Linear, QLinear]: +errors = {QLinear: [], nn.Linear: []} + +for linear_layer in errors.keys(): + for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: # The model model = nn.Sequential( @@ -114,7 +116,33 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: ###################################################################### - print( - f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %" - ) - sys.stdout.flush() + errors[linear_layer].append((nb_hidden, test_error)) + +import matplotlib.pyplot as plt + +fig = plt.figure() +fig.set_figheight(6) +fig.set_figwidth(8) + +ax = fig.add_subplot(1, 1, 1) + +ax.set_ylim(0, 1) +ax.spines.right.set_visible(False) +ax.spines.top.set_visible(False) +ax.set_xscale("log") +ax.set_xlabel("Nb hidden units") +ax.set_ylabel("Test error (%)") + +X = torch.tensor([x[0] for x in errors[nn.Linear]]) +Y = torch.tensor([x[1] for x in errors[nn.Linear]]) +ax.plot(X, Y, color="gray", label="nn.Linear") + +X = torch.tensor([x[0] for x in errors[QLinear]]) +Y = torch.tensor([x[1] for x in errors[QLinear]]) +ax.plot(X, Y, color="red", label="QLinear") + +ax.legend(frameon=False, loc=1) + +filename = f"bit_mlp.pdf" +print(f"saving {filename}") +fig.savefig(filename, bbox_inches="tight") -- 2.39.5