# Written by Francois Fleuret <francois@fleuret.org>
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
import math
import torch
######################################################################
-
-class WithResidual(nn.Module):
- def __init__(self, *f):
- super().__init__()
- self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
- def forward(self, bs):
- bs.x = bs.x + self.f(bs).x
- return bs
-
-
-######################################################################
-
# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to compute.
def slice(self):
return self.x[:, self.first : self.first + self.nb]
+ def complete(self):
+ return self.first == 0 and self.nb == self.x.size(1)
+
######################################################################
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
- bs.x = self.cache_y
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
- return bs
+
+##############################
+
+
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, bs):
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
##############################
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
##############################
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
):
super().__init__()
self.causal = causal
self.attention_dropout = attention_dropout
+ self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
def forward(self, bs_q):
x_q = bs_q.x
+ assert (
+ self.causal or bs_q.complete()
+ ), "Partial evaluation is only possible for causal models"
+
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
+
self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
- bs_q.x = self.cache_y
+ return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+
+
+##############################
+
+
+class NoiseInjector(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.noise_std = 0.0
+
+ def forward(self, x):
+ if self.noise_std > 0:
+ x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+ return x
+
- return bs_q
+def set_noise_injection(model, noise_std):
+ for m in model.modules():
+ if isinstance(m, NoiseInjector):
+ m.noise_std = noise_std
##############################
for b in range(nb_blocks):
trunk_blocks += [
WithResidual(
- CacheWrapper(nn.LayerNorm((dim_model,))),
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
+ ),
QKVAttention(
dim_in=dim_model,
dim_qk=dim_keys,
WithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
nn.Linear(in_features=dim_model, out_features=dim_hidden),
nn.ReLU(),
nn.Linear(in_features=dim_hidden, out_features=dim_model),
m.weight.fill_(1.0)
def forward(self, bs):
- bs.x = F.pad(bs.x, (1, -1))
+ # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
return bs
+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+
######################################################################
if __name__ == "__main__":
print("Basic check.")
- vocabulary_size = 10
- x = torch.randint(vocabulary_size, (9, 7))
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, 5))
model = MyGPT(
vocabulary_size=vocabulary_size,
- dim_model=18,
- dim_keys=50,
- dim_hidden=100,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
+ causal=True,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
-
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.x[:, s]
+ y2[:, s] = z.slice()
- # print(y1.max(dim = 2).values)
- # print(y2.max(dim = 2).values)
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
######################################################################