From 67a01a0a20ab174a43f3c1550e03988a153c8351 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 10 Aug 2024 00:24:39 +0200 Subject: [PATCH] Update. --- main.py | 21 ++++++++++++++++++++- mygpt.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 3196fbd..c4dcfb2 100755 --- a/main.py +++ b/main.py @@ -446,7 +446,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" ) - loss = loss_per_token.mean() + loss = loss_per_token.mean() + model.loss acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) @@ -1061,6 +1061,25 @@ for k in range(args.nb_gpts): ###################################################################### +if args.test == "quant": + nb_bits = 8 + for model in models: + model.trunk.insert( + 12, + mygpt.CacheWrapper( + mygpt.RandomBypass( + nn.Sequential( + nn.Linear(args.dim_model, nb_bits), + mygpt.BSQ(nb_bits), + nn.Linear(nb_bits, args.dim_model), + ), + 0.1, + ) + ), + ) + +###################################################################### + current_epoch = 0 if args.resume: diff --git a/mygpt.py b/mygpt.py index 2706143..7827757 100755 --- a/mygpt.py +++ b/mygpt.py @@ -19,6 +19,45 @@ from torch.nn import functional as F ###################################################################### + +class BSQ(nn.Module): + def __init__(self, L): + super().__init__() + self.L = L + + def forward(self, input, indexes=False): + norm = input.pow(2).sum(dim=1, keepdim=True).sqrt() + u = input / norm + + if indexes: + return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1) + + hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1) + if self.training: + self.loss += u.mean(dim=0).tanh().pow(2).mean() + return hat_u + u - u.detach() + else: + return hat_u + + +class RandomBypass(nn.Module): + def __init__(self, m, p): + super().__init__() + self.m = m + self.p = p + + def forward(self, x): + y = self.m(x) + + if self.training: + u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None] + return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size()) + else: + return y + + +###################################################################### + # A BracketedSequence is a BxTx... tensor with a first and a nb time # steps to compute. @@ -328,11 +367,18 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): + for m in self.modules(): + m.loss = 0 + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs) bs = self.readout(bs) bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature + + for m in self.modules(): + self.loss += m.loss + return bs def encode(self, bs): -- 2.39.5