loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
- loss = loss_per_token.mean()
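+ # model.loss carries the auxiliary losses (e.g. the BSQ bit-balance term)
+ # accumulated by the model's modules during the forward pass; note that
+ # acc_train_loss below therefore includes this auxiliary term as well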
+ loss = loss_per_token.mean() + model.loss
acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
######################################################################
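+# When testing quantization, insert a bottleneck at position 12 of the trunk:
+# a linear projection down to nb_bits channels, binary spherical quantization
+# (BSQ) of that low-dimensional code, and a linear projection back up to
+# dim_model. During training, RandomBypass lets each sample skip the
+# bottleneck with probability 0.1.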
+if args.test == "quant":
+    nb_bits = 8
+    for model in models:
+        model.trunk.insert(
+            12,
+            mygpt.CacheWrapper(
+                mygpt.RandomBypass(
+                    nn.Sequential(
+                        nn.Linear(args.dim_model, nb_bits),
+                        mygpt.BSQ(nb_bits),
+                        nn.Linear(nb_bits, args.dim_model),
+                    ),
+                    0.1,
+                )
+            ),
+        )
+
+######################################################################
+
current_epoch = 0
if args.resume:
######################################################################
+
+class BSQ(nn.Module):
+    # Binary spherical quantizer: project each feature vector onto the unit
+    # hypersphere, then binarize each of its L coordinates to +/- 1/sqrt(L)
+    def __init__(self, L):
+        super().__init__()
+        self.L = L
+
+    def forward(self, input, indexes=False):
+        # normalize over the feature (last) dimension
+        norm = input.pow(2).sum(dim=-1, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            # read the signs as bits and return the integer code in [0, 2**L)
+            return (
+                (u >= 0).long() * (2 ** torch.arange(self.L, device=u.device))
+            ).sum(dim=-1)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            # bit-balance regularizer pushing each bit's batch average toward 0;
+            # self.loss is reset to 0 at the start of the model's forward pass
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            # straight-through estimator: the value is hat_u, the gradient flows through u
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
+
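+# Illustrative note: for an input z of shape (B, T, L), BSQ(L)(z) returns the
+# binarized code hat_u (with a straight-through gradient in training mode),
+# while BSQ(L)(z, indexes=True) returns the corresponding integer indices in
+# [0, 2**L).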
+
+class RandomBypass(nn.Module):
+    # Wraps a module m; during training, each sample of the batch bypasses m
+    # (i.e. is passed through unchanged) with probability p
+    def __init__(self, m, p):
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        y = self.m(x)
+
+        if self.training:
+            # per-sample Bernoulli mask: 1 means bypass, 0 means keep m's output
+            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+        else:
+            return y
+
+
+######################################################################
+
# A BracketedSequence is a BxTx... tensor with a first time step and a
# number (nb) of time steps to compute.
m.weight.fill_(1.0)
def forward(self, bs):
+    # reset the auxiliary loss of every sub-module before the forward pass
+    for m in self.modules():
+        m.loss = 0
+
    bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
    bs = self.embedding(bs)
    bs = self.trunk(bs)
    bs = self.readout(bs)
    bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+    # collect the auxiliary losses accumulated by the sub-modules (e.g. BSQ)
+    for m in self.modules():
+        self.loss += m.loss
+
    return bs
def encode(self, bs):