self.w_v = randw(nb_heads, dim_v, dim_in)
self.w_o = randw(dim_v * nb_heads, dim_in)
- def forward(self, bs_q):
+ def forward(self, bs_q, bs_kv=None):
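+        # bs_q provides the queries; bs_kv, which defaults to bs_q for plain
+        # self-attention, provides the keys and values, so the same layer can
+        # also be used for cross-attention over a separate cached sequence.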
+ if bs_kv is None:
+ bs_kv = bs_q
+
x_q = bs_q.x
+ x_kv = bs_kv.x
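+        # The key/value caches are (re)allocated whenever the key/value
+        # sequence starts over at position 0.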
- if bs_q.first == 0:
- self.cache_k = x_q.new_zeros(
- x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+ if bs_kv.first == 0:
+ self.cache_k = x_kv.new_zeros(
+ x_kv.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1)
)
- self.cache_v = x_q.new_zeros(
- x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+ self.cache_v = x_kv.new_zeros(
+ x_kv.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1)
)
+
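+        # The output cache follows the query sequence and is reset when it
+        # restarts at position 0.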
+ if bs_q.first == 0:
self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
- self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
+ self.cache_k[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_k
)
- self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
- "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
+ self.cache_v[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+ "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_v
)
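+        # Attention scores between the current queries and every cached key
+        # position seen so far, scaled by sqrt(dim_qk).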
a = torch.einsum(
- "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
+ "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
) / math.sqrt(self.w_q.size(1))
if self.compute_attzero is not None:
if bs_q.first == 0:
self.cache_attzero = self.compute_attzero(
torch.arange(x_q.size(1), device=q.device)[:, None],
- torch.arange(x_q.size(1), device=q.device)[None, :],
+ torch.arange(x_kv.size(1), device=q.device)[None, :],
)[None, None, :, :]
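+            # Mask the positions flagged by compute_attzero with -inf so they
+            # contribute nothing to the attention weights.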
a = a.masked_fill(
self.cache_attzero[
- :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
+ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
],
float("-inf"),
)
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
- "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
+ "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_kv.first + bs_kv.nb]
).flatten(2)
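+        # Project the concatenated heads with w_o and store the result at the
+        # current query positions in the output cache.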
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
##############################
+class BlockSummarizer(nn.Module):
+    def __init__(
+        self, nb_blocks, nb_tokens, dim_keys, dim_model, nb_heads, dropout=0.0
+    ):
+        super().__init__()
+        self.nb_blocks = nb_blocks
+        # One set of learned static queries per block boundary
+        self.static_q = nn.Parameter(torch.randn(nb_blocks - 1, nb_tokens, dim_keys))
+
+        def compute_block_attzero(t_q, t_k):
+            # Assumes the sequence length splits evenly into nb_blocks blocks;
+            # a query may only attend to keys from strictly earlier blocks.
+            block_size = t_q.size(0) // nb_blocks
+            return (t_q // block_size) <= (t_k // block_size)
+
+        self.qkv = QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            compute_attzero=compute_block_attzero,
+            attention_dropout=dropout,
+        )
+
+ def forward(self, bs):
+ pass
+
+
+class ShiftByOne(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, bs):
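+        # F.pad with (1, -1) prepends one zero token and drops the last one,
+        # shifting the whole sequence one position to the right.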
+ return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+
+
class MyGPT(nn.Module):
def __init__(
self,
nb_heads,
nb_blocks,
compute_attzero=None,
- autoencoder_dim=-1,
dropout=0.0,
len_max=1e5,
):
self.temperature = 1.0
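+        # The right-shift of the input, previously done inline in forward(),
+        # is now a dedicated module.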
+ self.shifter = ShiftByOne()
+
self.embedding = nn.Sequential(
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
)
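+        # Positional encoding is no longer fused into the embedding Sequential;
+        # it is applied as a separate step in forward().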
+ self.positional_encoding = AddPositionalEncoding(len_max)
+
trunk_blocks = []
for b in range(nb_blocks):
nn.Linear(in_features=dim_model, out_features=vocabulary_size)
)
- # -------------------------------------------------------
- if autoencoder_dim > 0:
- self.encoder = nn.Sequential(
- *(
- trunk_blocks[: nb_blocks // 2]
- + [EncoderHead(dim_model, autoencoder_dim)]
- )
- )
-
- self.decoder = nn.Sequential(
- *(
- [
- DecoderBottom(autoencoder_dim, dim_model),
- AddPositionalEncoding(len_max),
- ]
- + trunk_blocks[nb_blocks // 2 :]
- )
- )
- # -------------------------------------------------------
-
with torch.no_grad():
for m in self.modules():
if isinstance(m, nn.Embedding):
for m in self.modules():
m.loss = 0
- bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ bs = self.shifter(bs)
bs = self.embedding(bs)
+ bs = self.positional_encoding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
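+        # Scale the readout logits by the sampling temperature.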
bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
return bs
- def encode(self, bs):
- bs = self.embedding(bs)
- z = self.encoder(bs)
- return z
-
- def decode(self, z_shape):
- bs = self.decoder(z_shape)
- bs = self.readout(bs)
- return bs
-
- def partial_forward(self, bs, start_layer=None, end_layer=None):
- if start_layer is None:
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
- bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
- bs = self.embedding(bs)
- if end_layer is not None:
- return self.trunk[:end_layer](bs)
- else:
- bs = self.trunk(bs)
- bs = self.readout(bs)
- return bs
- else:
- bs = self.trunk[start_layer:](bs)
- bs = self.trunk(bs)
- bs = self.readout(bs)
- return bs
-
def reset_transformations(self):
self.temperature = 1.0
for m in self.modules():