import torch, torchvision
from torch import nn
-lr, nb_epochs, batch_size = 2e-3, 50, 100
+lr, nb_epochs, batch_size = 2e-3, 100, 100
data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
######################################################################
-for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
-    for linear_layer in [nn.Linear, QLinear]:
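+# Test error of each run, keyed by linear-layer class, stored as (nb_hidden, test_error) pairs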
+errors = {QLinear: [], nn.Linear: []}
+
+for linear_layer in errors.keys():
+    for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
        # The model
        model = nn.Sequential(
######################################################################
-        print(
-            f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %"
-        )
-        sys.stdout.flush()
+        errors[linear_layer].append((nb_hidden, test_error))
+
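+# Plot the test error as a function of the number of hidden units, one curve per layer type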
+import matplotlib.pyplot as plt
+
+fig = plt.figure()
+fig.set_figheight(6)
+fig.set_figwidth(8)
+
+ax = fig.add_subplot(1, 1, 1)
+
+ax.set_ylim(0, 1)
+ax.spines.right.set_visible(False)
+ax.spines.top.set_visible(False)
+ax.set_xscale("log")
+ax.set_xlabel("Nb hidden units")
+ax.set_ylabel("Test error")
+
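+# One curve per linear-layer class: nn.Linear in gray, QLinear in red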
+X = torch.tensor([x[0] for x in errors[nn.Linear]])
+Y = torch.tensor([x[1] for x in errors[nn.Linear]])
+ax.plot(X, Y, color="gray", label="nn.Linear")
+
+X = torch.tensor([x[0] for x in errors[QLinear]])
+Y = torch.tensor([x[1] for x in errors[QLinear]])
+ax.plot(X, Y, color="red", label="QLinear")
+
+ax.legend(frameon=False, loc=1)
+
+filename = "bit_mlp.pdf"
+print(f"saving {filename}")
+fig.savefig(filename, bbox_inches="tight")