Update.
author     François Fleuret <francois@fleuret.org>
           Tue, 26 Mar 2024 17:44:06 +0000 (18:44 +0100)
committer  François Fleuret <francois@fleuret.org>
           Tue, 26 Mar 2024 17:44:06 +0000 (18:44 +0100)
bit_mlp.py [new file with mode: 0755]

diff --git a/bit_mlp.py b/bit_mlp.py
new file mode 100755 (executable)
index 0000000..85262b7
--- /dev/null
@@ -0,0 +1,120 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
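+# Compares a small MLP built with standard nn.Linear layers against the same
+# MLP built with linear layers whose parameters are quantized to {-1, 0, 1},
+# on MNIST.
+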
+import os, sys
+import torch, torchvision
+from torch import nn
+
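+# Training hyper-parameters: learning rate, number of epochs, mini-batch size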
+lr, nb_epochs, batch_size = 2e-3, 50, 100
+
+data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
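+# Load the MNIST train and test sets and move them to the device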
+train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True)
+train_input = train_set.data.view(-1, 1, 28, 28).float()
+train_targets = train_set.targets
+
+test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True)
+test_input = test_set.data.view(-1, 1, 28, 28).float()
+test_targets = test_set.targets
+
+train_input, train_targets = train_input.to(device), train_targets.to(device)
+test_input, test_targets = test_input.to(device), test_targets.to(device)
+
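+# Normalize both sets with the mean and standard deviation of the training set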
+mu, std = train_input.mean(), train_input.std()
+
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+######################################################################
+
+
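+# Linear layer whose weights and bias are quantized on the fly to the ternary
+# values {-1, 0, 1}, with gradients passed through the quantizer unchanged
+# (straight-through estimator)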
+class QLinear(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn(dim_out, dim_in))
+        self.b = nn.Parameter(torch.randn(dim_out) * 1e-1)
+
+    def quantize(self, z):
+        epsilon = 1e-3
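+        # Scale z so that its mean absolute value is ~1, then round to {-1, 0, 1}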
+        zr = z / (z.abs().mean() + epsilon)
+        zq = -(zr <= -0.5).long() + (zr >= 0.5).long()
+        if self.training:
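+            # The forward value is zq; z - z.detach() is zero in the forward
+            # pass but lets the gradient flow to z as if quantization were the identity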
+            return zq + z - z.detach()
+        else:
+            return zq.float()
+
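+    # Standard affine transform, computed with the quantized parameters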
+    def forward(self, x):
+        return x @ self.quantize(self.w).t() + self.quantize(self.b)
+
+
+######################################################################
+
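+# Train a standard and a quantized two-layer MLP for a range of hidden layer widths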
+for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]:
+    for linear_layer in [nn.Linear, QLinear]:
+        model = nn.Sequential(
+            nn.Flatten(),
+            linear_layer(784, nb_hidden),
+            nn.ReLU(),
+            linear_layer(nb_hidden, 10),
+        ).to(device)
+
+        nb_parameters = sum(p.numel() for p in model.parameters())
+
+        print(f"nb_parameters {nb_parameters}")
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+        ######################################################################
+
+        for k in range(nb_epochs):
+            ############################################
+            # Train
+
+            model.train()
+
+            acc_train_loss = 0.0
+
+            for input, targets in zip(
+                train_input.split(batch_size), train_targets.split(batch_size)
+            ):
+                output = model(input)
+                loss = torch.nn.functional.cross_entropy(output, targets)
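+                # cross_entropy returns the batch mean, so scale by the batch
+                # size to accumulate the total loss over the epoch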
+                acc_train_loss += loss.item() * input.size(0)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            ############################################
+            # Test
+
+            model.eval()
+
+            nb_test_errors = 0
+            for input, targets in zip(
+                test_input.split(batch_size), test_targets.split(batch_size)
+            ):
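+                # Predicted class = index of the largest output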
+                wta = model(input).argmax(1)
+                nb_test_errors += (wta != targets).long().sum()
+            test_error = nb_test_errors / test_input.size(0)
+
+            if (k + 1) % 10 == 0:
+                print(
+                    f"loss {k+1} {acc_train_loss/train_input.size(0)} {test_error*100:.02f}%"
+                )
+                sys.stdout.flush()
+
+        ######################################################################
+
+        print(
+            f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %"
+        )
+        sys.stdout.flush()