def slice(self):
return self.x[:, self.first : self.first + self.nb]
+ def complete(self):
+ return self.first == 0 and self.nb == self.x.size(1)
+
######################################################################
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
##############################
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
- bs.x = bs.x + self.f(bs).x
- return bs
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
##############################
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
##############################
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
):
super().__init__()
self.causal = causal
self.attention_dropout = attention_dropout
+ self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
def forward(self, bs_q):
x_q = bs_q.x
+ assert (
+ self.causal or bs_q.complete()
+ ), "Partial evaluation is only possible for causal models"
+
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
+
self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
- bs_q.x = self.cache_y
-
- return bs_q
+ return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
##############################
m.weight.fill_(1.0)
def forward(self, bs):
- bs.x = F.pad(bs.x, (1, -1))
+ # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
return bs
+ # ar_mask is a tensor with 0s and 1s, of same shape as input, with
+ # 1s where tokens should be generated. The others are kept
+ # unchanged.
+
+ def masked_inplace_autoregression(
+ self,
+ input,
+ ar_mask,
+ summed_logits,
+ temperature=1.0,
+ deterministic_synthesis=False,
+ forbidden_tokens=None,
+ forced_biases=None,
+ ):
+ to_generate = (ar_mask.sum(0) > 0).nonzero()
+
+ if to_generate.min() > 0:
+ self(
+ BracketedSequence(input, 0, to_generate.min())
+ ) # Needed to initialize the model's cache
+ for s in range(to_generate.min(), to_generate.max() + 1):
+ output = self(BracketedSequence(input, s, 1)).x
+
+ logits = output[:, s]
+
+ logits = (logits / temperature).log_softmax(dim=-1)
+
+ if forbidden_tokens is not None:
+ logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+
+ if forced_biases is not None:
+ logits = logits + forced_biases[None, :]
+
+ if deterministic_synthesis:
+ t_next = logits.argmax(-1)
+ else:
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ t_next = dist.sample()
+ if summed_logits is not None:
+ summed_logits += logits[torch.arange(t_next.size(0)), t_next].sum(
+ dim=-1
+ )
+
+ input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
+
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
+
######################################################################
if __name__ == "__main__":
print("Basic check.")
- vocabulary_size = 10
- x = torch.randint(vocabulary_size, (9, 7))
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, 5))
model = MyGPT(
vocabulary_size=vocabulary_size,
- dim_model=18,
- dim_keys=50,
- dim_hidden=100,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
+ causal=True,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
-
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.x[:, s]
+ y2[:, s] = z.slice()
- # print(y1.max(dim = 2).values)
- # print(y2.max(dim = 2).values)
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
######################################################################