# Written by Francois Fleuret <francois@fleuret.org>
+# This is an implementation from scratch of a "GPT", that is, a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values that avoids an O(N^3)
+# cost for auto-regression (N forward passes, each with O(N^2)
+# attention), reducing it to O(N^2).
+
import math
import torch

from torch import nn
from torch.nn import functional as F
######################################################################
-class WithResidual(nn.Module):
-    def __init__(self, *f):
+class BSQ(nn.Module):
+    def __init__(self, L):
         super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.L = L
+        # reset by MyGPT.forward(); initialized here so that BSQ can
+        # also be used standalone
+        self.loss = 0

-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
+    def forward(self, input, indexes=False):
+        # project the (N, T, L) channel vectors onto the unit sphere of R^L
+        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            # one integer code per token, whose bits are the signs
+            # along the channel dimension
+            return (
+                (u >= 0).long()
+                * (2 ** torch.arange(self.L, device=u.device))[None, None, :]
+            ).sum(dim=2)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            # straight-through estimator: hat_u in the forward pass,
+            # the gradient of u in the backward pass
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
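+
+# Illustrative usage sketch (not part of the original file): quantize a
+# (N, T, L) activation tensor, and get one integer code per token with
+# indexes=True.
+#
+#   bsq = BSQ(8)
+#   h = torch.randn(2, 5, 8)
+#   codes = bsq(h)               # (2, 5, 8), entries are +-1/sqrt(8)
+#   idx = bsq(h, indexes=True)   # (2, 5), integers in [0, 255]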
+
+
+class RandomBypass(nn.Module):
+    def __init__(self, m, p):
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        y = self.m(x)
+
+        if self.training:
+            # during training, each sample of the batch bypasses the
+            # wrapped module with probability p
+            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+        else:
+            return y
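+
+# Illustrative usage sketch (not part of the original file); the
+# wrapped module must preserve the shape of its input:
+#
+#   block = RandomBypass(nn.Linear(16, 16), 0.5)
+#   block.train()
+#   y = block(torch.randn(4, 16))   # ~half the rows pass through unchanged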
######################################################################
    def slice(self):
        return self.x[:, self.first : self.first + self.nb]

+    def complete(self):
+        return self.first == 0 and self.nb == self.x.size(1)
+
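+# Illustrative note (not part of the original file): a BracketedSequence
+# carries a full-width tensor x in which only the columns
+# [first, first + nb) are freshly computed, the rest being cached, e.g.
+#
+#   BracketedSequence(x)        # the whole sequence is new
+#   BracketedSequence(x, 3, 1)  # only the token at position 3 is new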
######################################################################
        else:
            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())

-        bs.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-        return bs
+
+##############################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
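+
+# Illustrative usage sketch (not part of the original file): a pre-norm
+# residual block bs.x -> bs.x + f(bs.x):
+#
+#   blk = WithResidual(
+#       CacheWrapper(nn.LayerNorm((8,)), nn.Linear(8, 8)),
+#   )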
##############################
            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
        )

-        bs.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-        return bs
+
+##############################
+
+
+class EncoderHead(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        # one latent vector per sequence, obtained by averaging over
+        # the temporal dimension; the input shape is returned so that
+        # the decoder can rebuild a sequence of the same size
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        # broadcast the latent vector to every temporal position
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
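+
+# Illustrative shape check (not part of the original file):
+#
+#   head = EncoderHead(8, 3)
+#   z, shape = head(BracketedSequence(torch.randn(2, 5, 8)))  # z: (2, 3)
+#   bs = DecoderBottom(3, 8)((z, shape))                      # bs.x: (2, 5, 8)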
##############################
class QKVAttention(nn.Module):
    def __init__(
-        self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        compute_attzero=None,
+        attention_dropout=0.0,
    ):
        super().__init__()

        def randw(*d):
            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

-        self.causal = causal
+        self.compute_attzero = compute_attzero
        self.attention_dropout = attention_dropout
+        self.record_attention = False

        self.w_q = randw(nb_heads, dim_qk, dim_in)
        self.w_k = randw(nb_heads, dim_qk, dim_in)
        q = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
        )
+
        self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
        )

        a = torch.einsum(
            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
        ) / math.sqrt(self.w_q.size(1))
-        if self.causal:
+        if self.compute_attzero is not None:
            if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]

            a = a.masked_fill(
                self.cache_attzero[
                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
                ],
                float("-inf"),
            )
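+        # Illustrative note (not part of the original file): the former
+        # causal=True behavior corresponds to passing
+        #
+        #   compute_attzero = lambda t, s: t < s
+        #
+        # which masks attention from each position t to positions s > t.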
        a = a.softmax(dim=3)
+
+        if self.record_attention:
+            self.a = a
+
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum(
            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
        ).flatten(2)

        self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o

-        bs_q.x = self.cache_y
+        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
-        return bs_q
+
+
+##############################
+
+
+class NoiseInjector(nn.Module):
+    def __init__(self, identifier=None):
+        super().__init__()
+        self.noise_std = 0.0
+        self.identifier = identifier
+
+    def forward(self, x):
+        # flip the sign of each component of x with probability
+        # noise_std; a no-op when noise_std is 0
+        if self.noise_std > 0:
+            x = x * (
+                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+            )
+        return x
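+
+# Illustrative usage sketch (not part of the original file):
+#
+#   ni = NoiseInjector(identifier=("ffw", 0))
+#   ni.noise_std = 0.1
+#   y = ni(torch.ones(4, 8))   # ~10% of the entries become -1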
##############################
        dim_hidden,
        nb_heads,
        nb_blocks,
-        causal=False,
+        compute_attzero=None,
+        autoencoder_dim=-1,
        dropout=0.0,
        len_max=1e5,
    ):
        super().__init__()

        assert dim_model % nb_heads == 0

+        self.temperature = 1.0
+
        self.embedding = nn.Sequential(
            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
            AddPositionalEncoding(len_max),
        )
        trunk_blocks = []

        for b in range(nb_blocks):
            trunk_blocks += [
                WithResidual(
-                    CacheWrapper(nn.LayerNorm((dim_model,))),
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("attention", b)),
+                    ),
                    QKVAttention(
                        dim_in=dim_model,
                        dim_qk=dim_keys,
                        dim_v=dim_model // nb_heads,
                        nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                        attention_dropout=dropout,
                    ),
                ),
                WithResidual(
                    CacheWrapper(
                        nn.LayerNorm((dim_model,)),
+                        NoiseInjector(identifier=("ffw", b)),
                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
                        nn.ReLU(),
                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
                    ),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = CacheWrapper(
            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
        )
+        # -------------------------------------------------------
+        if autoencoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, autoencoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(autoencoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+        # -------------------------------------------------------
+
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding):
                    m.weight.normal_(mean=0, std=2e-2)
                elif isinstance(m, nn.LayerNorm):
                    m.bias.zero_()
                    m.weight.fill_(1.0)
    def forward(self, bs):
-        bs.x = F.pad(bs.x, (1, -1))
+        for m in self.modules():
+            m.loss = 0
+
+        # the input is shifted right by one position, so that the
+        # logits at position t are conditioned on the tokens < t only
+        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
        bs = self.embedding(bs)
        bs = self.trunk(bs)
        bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+        for m in self.modules():
+            self.loss += m.loss
+
+        return bs
+
+    def encode(self, bs):
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
+
+    def decode(self, z_shape):
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
        return bs
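+
+    # Illustrative usage (not part of the original file), for a model
+    # built with autoencoder_dim > 0; encode() returns the (latent,
+    # shape) pair produced by EncoderHead:
+    #
+    #   z_shape = model.encode(BracketedSequence(x))
+    #   logits = model.decode(z_shape).x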
+    def partial_forward(self, bs, start_layer=None, end_layer=None):
+        if start_layer is None:
+            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = self.embedding(bs)
+            if end_layer is not None:
+                return self.trunk[:end_layer](bs)
+            else:
+                bs = self.trunk(bs)
+                bs = self.readout(bs)
+                return bs
+        else:
+            bs = self.trunk[start_layer:](bs)
+            bs = self.readout(bs)
+            return bs
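+
+    # Illustrative usage (not part of the original file): split the
+    # forward pass at layer k, e.g. to inspect or modify intermediate
+    # activations:
+    #
+    #   h = model.partial_forward(BracketedSequence(x), end_layer=k)
+    #   logits = model.partial_forward(h, start_layer=k).x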
+
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
+
+    def record_attention(self, v=True):
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                m.record_attention = v
+
+    def retrieve_attention(self):
+        a = []
+        for m in self.modules():
+            if isinstance(m, QKVAttention):
+                a.append(m.a)
+        return a
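+
+    # Illustrative usage (not part of the original file):
+    #
+    #   model.set_noise_injection(0.05, identifier=("ffw", 0))
+    #   model.record_attention()
+    #   y = model(BracketedSequence(x))
+    #   attention_maps = model.retrieve_attention()
+    #   model.reset_transformations()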
+
######################################################################
if __name__ == "__main__":
print("Basic check.")
- vocabulary_size = 10
- x = torch.randint(vocabulary_size, (9, 7))
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, 5))
model = MyGPT(
vocabulary_size=vocabulary_size,
- dim_model=18,
- dim_keys=50,
- dim_hidden=100,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
-
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.x[:, s]
+ y2[:, s] = z.slice()
- # print(y1.max(dim = 2).values)
- # print(y2.max(dim = 2).values)
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
######################################################################