Update.
author François Fleuret <francois@fleuret.org>
Fri, 9 Aug 2024 22:24:39 +0000 (00:24 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 9 Aug 2024 22:24:39 +0000 (00:24 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 3196fbd..c4dcfb2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -446,7 +446,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
         )
-        loss = loss_per_token.mean()
+        loss = loss_per_token.mean() + model.loss
         acc_train_loss += loss.item() * input.size(0)
 
         loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
@@ -1061,6 +1061,25 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
+# Optional "quant" experiment: every model gets a quantization bottleneck
+# inserted in its trunk.
+if args.test == "quant":
+    nb_bits = 8
+    for model in models:
+        # dim_model -> nb_bits -> binary quantization -> dim_model
+        quantizer = nn.Sequential(
+            nn.Linear(args.dim_model, nb_bits),
+            mygpt.BSQ(nb_bits),
+            nn.Linear(nb_bits, args.dim_model),
+        )
+        # In training mode each sample skips the bottleneck with
+        # probability 0.1 (see mygpt.RandomBypass).
+        bypass = mygpt.RandomBypass(quantizer, 0.1)
+        # NOTE(review): the insertion index 12 is hard-coded -- confirm it
+        # matches the intended position for the trunk sizes in use.
+        model.trunk.insert(12, mygpt.CacheWrapper(bypass))
+
+######################################################################
+
 current_epoch = 0
 
 if args.resume:
diff --git a/mygpt.py b/mygpt.py
index 2706143..7827757 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -19,6 +19,45 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+class BSQ(nn.Module):
+    """Sign-based quantizer of L2-normalized vectors.
+
+    The input is divided by its L2 norm computed along dimension 1, and
+    each component is then replaced by +1/sqrt(L) or -1/sqrt(L) according
+    to its sign. In training mode the gradient is passed straight through
+    the quantization, and a regularization term is accumulated in
+    `self.loss`.
+
+    NOTE(review): the norm, the codes and the `[None, :]` broadcast all
+    reduce over dimension 1, which assumes a 2d input of shape (N, L) --
+    confirm against the caller before feeding tensors of higher rank.
+    """
+
+    def __init__(self, L):
+        """L is the number of components, hence of bits per vector."""
+        super().__init__()
+        self.L = L
+        # Accumulator for the regularization term added by forward() in
+        # training mode. It is initialized here so that the module also
+        # works when no enclosing model has reset it beforehand; whoever
+        # consumes it is responsible for resetting it.
+        self.loss = 0
+
+    def forward(self, input, indexes=False):
+        """Quantize `input`.
+
+        Args:
+            input: tensor to quantize (see the class note about its shape).
+            indexes: if True, return the integer code of each vector, where
+                bit k is set when component k is non-negative, instead of
+                the quantized vector.
+
+        Returns:
+            A long tensor of codes if `indexes` is True, otherwise the
+            quantized tensor, of same shape as `input`.
+        """
+        # A vector of zero norm yields NaNs here.
+        norm = input.pow(2).sum(dim=1, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            # The powers of two have to be on the same device as the input,
+            # otherwise the product fails for a model living on a GPU.
+            powers = 2 ** torch.arange(self.L, device=u.device)
+            return ((u >= 0).long() * powers[None, :]).sum(dim=1)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            # Penalize components whose mean over dimension 0 is not zero.
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            # Straight-through estimator: the value is hat_u, the gradient
+            # is the one of u.
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
+
+
+class RandomBypass(nn.Module):
+    """Wrap a module so that, in training mode, samples randomly skip it.
+
+    In evaluation mode the output is always the one of the wrapped module.
+    In training mode each sample (index along dimension 0) independently
+    gets its input returned unchanged with probability `p`, and the output
+    of the wrapped module otherwise. The wrapped module is always called.
+    """
+
+    def __init__(self, m, p):
+        """`m` is the wrapped module and `p` the probability to bypass it."""
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        """Return `self.m(x)` with some samples replaced by `x` in training."""
+        transformed = self.m(x)
+
+        if not self.training:
+            return transformed
+
+        # One 0/1 gate per sample, 1 meaning that the sample is bypassed.
+        bypassed = torch.rand(x.size(0), device=x.device) <= self.p
+        gate = bypassed.long()[:, None]
+        mixed = gate * x.flatten(1) + (1 - gate) * transformed.flatten(1)
+        return mixed.reshape(x.size())
+
+
+######################################################################
+
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -328,11 +367,18 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
+        """Run the model on the bracketed sequence `bs` and return the result.
+
+        As a side effect, the `loss` attribute of every module of the model
+        is reset to 0 before the computation, and `self.loss` is set
+        afterward to the sum of the `loss` attributes of all the modules.
+        """
+        # Reset the auxiliary losses of all the modules, self included.
+        for m in self.modules():
+            m.loss = 0
+
+        # Shift the last dimension of bs.x by one position: one padding
+        # value is added in front and the last value is dropped.
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        # Divide in place the time steps just computed by the temperature.
         bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+        # self.modules() yields self first, whose loss is still 0 at that
+        # point, so this adds up the losses set by the sub-modules.
+        for m in self.modules():
+            self.loss += m.loss
+
         return bs
 
     def encode(self, bs):