From 53d7745e661073bad93752ea41b0320312250954 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 26 Mar 2024 18:44:06 +0100 Subject: [PATCH] Update. --- bit_mlp.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100755 bit_mlp.py diff --git a/bit_mlp.py b/bit_mlp.py new file mode 100755 index 0000000..85262b7 --- /dev/null +++ b/bit_mlp.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import os, sys +import torch, torchvision +from torch import nn + +lr, nb_epochs, batch_size = 2e-3, 50, 100 + +data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True) +train_input = train_set.data.view(-1, 1, 28, 28).float() +train_targets = train_set.targets + +test_set = torchvision.datasets.MNIST(root=data_dir, train=False, download=True) +test_input = test_set.data.view(-1, 1, 28, 28).float() +test_targets = test_set.targets + +train_input, train_targets = train_input.to(device), train_targets.to(device) +test_input, test_targets = test_input.to(device), test_targets.to(device) + +mu, std = train_input.mean(), train_input.std() + +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +###################################################################### + + +class QLinear(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.w = nn.Parameter(torch.randn(dim_out, dim_in)) + self.b = nn.Parameter(torch.randn(dim_out) * 1e-1) + + def quantize(self, z): + epsilon = 1e-3 + zr = z / (z.abs().mean() + epsilon) + zq = -(zr <= -0.5).long() + (zr >= 0.5).long() + if self.training: + return zq + z - z.detach() + else: + return zq.float() + + def forward(self, x): + return x @ self.quantize(self.w).t() + self.quantize(self.b) + + +###################################################################### + +for nb_hidden in [16, 32, 64, 128, 256, 512, 1024]: + for linear_layer in [nn.Linear, QLinear]: + model = nn.Sequential( + nn.Flatten(), + linear_layer(784, nb_hidden), + nn.ReLU(), + linear_layer(nb_hidden, 10), + ).to(device) + + nb_parameters = sum(p.numel() for p in model.parameters()) + + print(f"nb_parameters {nb_parameters}") + + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + + ###################################################################### + + for k in range(nb_epochs): + ############################################ + # Train + + model.train() + + acc_train_loss = 0.0 + + for input, targets in zip( + train_input.split(batch_size), train_targets.split(batch_size) + ): + output = model(input) + loss = torch.nn.functional.cross_entropy(output, targets) + acc_train_loss += loss.item() * input.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + ############################################ + # Test + + model.eval() + + nb_test_errors = 0 + for input, targets in zip( + test_input.split(batch_size), test_targets.split(batch_size) + ): + wta = model(input).argmax(1) + nb_test_errors += (wta != targets).long().sum() + test_error = nb_test_errors / test_input.size(0) + + if (k + 1) % 10 == 0: + print( + f"loss {k+1} {acc_train_loss/train_input.size(0)} {test_error*100:.02f}%" + ) + sys.stdout.flush() + + ###################################################################### + + print( + f"final_loss {nb_hidden} {linear_layer} {acc_train_loss/train_input.size(0)} {test_error*100} %" + ) + sys.stdout.flush() -- 2.39.5