X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=bit_mlp.py;h=6f7f92e6172ee07b20cdf2cc89410a95ba6ab165;hb=HEAD;hp=85262b72f28f5c6b0cacd6ea8bb86d6d6967cd55;hpb=53d7745e661073bad93752ea41b0320312250954;p=pytorch.git diff --git a/bit_mlp.py b/bit_mlp.py index 85262b7..6f7f92e 100755 --- a/bit_mlp.py +++ b/bit_mlp.py @@ -9,7 +9,7 @@ import os, sys import torch, torchvision from torch import nn -lr, nb_epochs, batch_size = 2e-3, 50, 100 +lr, nb_epochs, batch_size = 2e-3, 100, 100 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/" @@ -57,8 +57,12 @@ class QLinear(nn.Module): ###################################################################### -for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: - for linear_layer in [nn.Linear, QLinear]: +errors = {QLinear: [], nn.Linear: []} + +for linear_layer in errors.keys(): + for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: + # The model + model = nn.Sequential( nn.Flatten(), linear_layer(784, nb_hidden), @@ -72,10 +76,9 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: optimizer = torch.optim.Adam(model.parameters(), lr=lr) - ###################################################################### + # for k in range(nb_epochs): - ############################################ # Train model.train() @@ -93,7 +96,6 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: loss.backward() optimizer.step() - ############################################ # Test model.eval() @@ -114,7 +116,40 @@ for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: ###################################################################### - print( - f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %" + errors[linear_layer].append( + (nb_hidden, test_error * 100, acc_train_loss / train_input.size(0)) ) - sys.stdout.flush() + +import matplotlib.pyplot as plt + + +def save_fig(filename, ymax, ylabel, index): + fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(8) + + ax = fig.add_subplot(1, 1, 1) + + ax.set_ylim(0, ymax) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + ax.set_xscale("log") + ax.set_xlabel("Nb hidden units") + ax.set_ylabel(ylabel) + + X = torch.tensor([x[0] for x in errors[nn.Linear]]) + Y = torch.tensor([x[index] for x in errors[nn.Linear]]) + ax.plot(X, Y, color="gray", label="nn.Linear") + + X = torch.tensor([x[0] for x in errors[QLinear]]) + Y = torch.tensor([x[index] for x in errors[QLinear]]) + ax.plot(X, Y, color="red", label="QLinear") + + ax.legend(frameon=False, loc=1) + + print(f"saving {filename}") + fig.savefig(filename, bbox_inches="tight") + + +save_fig("bit_mlp_err.pdf", ymax=15, ylabel="Test error (%)", index=1) +save_fig("bit_mlp_loss.pdf", ymax=1.25, ylabel="Train loss", index=2)