# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import os

import torch
import torchvision
from torch import nn
12 lr, nb_epochs, batch_size = 2e-3, 50, 100
14 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 ######################################################################
20 train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
21 train_input = train_set.data.view(-1, 1, 28, 28).float()
22 train_targets = train_set.targets
24 test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
25 test_input = test_set.data.view(-1, 1, 28, 28).float()
26 test_targets = test_set.targets
28 train_input, train_targets = train_input.to(device), train_targets.to(device)
29 test_input, test_targets = test_input.to(device), test_targets.to(device)
31 mu, std = train_input.mean(), train_input.std()
33 train_input.sub_(mu).div_(std)
34 test_input.sub_(mu).div_(std)
36 ######################################################################
39 class QLinear(nn.Module):
40 def __init__(self, dim_in, dim_out):
42 self.w = nn.Parameter(torch.randn(dim_out, dim_in))
43 self.b = nn.Parameter(torch.randn(dim_out) * 1e-1)
45 def quantize(self, z):
47 zr = z / (z.abs().mean() + epsilon)
48 zq = -(zr <= -0.5).long() + (zr >= 0.5).long()
50 return zq + z - z.detach()
55 return x @ self.quantize(self.w).t() + self.quantize(self.b)
58 ######################################################################
60 for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
61 for linear_layer in [nn.Linear, QLinear]:
62 model = nn.Sequential(
64 linear_layer(784, nb_hidden),
66 linear_layer(nb_hidden, 10),
69 nb_parameters = sum(p.numel() for p in model.parameters())
71 print(f"nb_parameters {nb_parameters}")
73 optimizer = torch.optim.Adam(model.parameters(), lr=lr)
75 ######################################################################
77 for k in range(nb_epochs):
78 ############################################
85 for input, targets in zip(
86 train_input.split(batch_size), train_targets.split(batch_size)
89 loss = torch.nn.functional.cross_entropy(output, targets)
90 acc_train_loss += loss.item() * input.size(0)
96 ############################################
102 for input, targets in zip(
103 test_input.split(batch_size), test_targets.split(batch_size)
105 wta = model(input).argmax(1)
106 nb_test_errors += (wta != targets).long().sum()
107 test_error = nb_test_errors / test_input.size(0)
109 if (k + 1) % 10 == 0:
111 f"loss {k+1} {acc_train_loss/train_input.size(0)} {test_error*100:.02f}%"
115 ######################################################################
118 f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %"