def slice(self):
return self.x[:, self.first : self.first + self.nb]
+ def complete(self):
+ return self.first == 0 and self.nb == self.x.size(1)
+
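+ # Usage sketch (the one-argument construction covers the whole sequence, as in the
+ # check at the bottom of this file):
+ #
+ #   bs = BracketedSequence(torch.randint(10, (2, 8)))  # first=0, nb=8
+ #   bs.complete()                                       # True
+ #   bs = BracketedSequence(bs.x, 3, 1)                  # bracket of one token at position 3
+ #   bs.slice().size()                                   # torch.Size([2, 1])
+ #   bs.complete()                                       # False
+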
######################################################################
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
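+
+ # Only the bracketed slice goes through self.f; its output is patched into
+ # self.cache_y, which persists across calls, so positions computed by earlier
+ # calls keep their values. This is what makes one-token-at-a-time generation cheap.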
##############################
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
- bs.x = bs.x + self.f(bs).x
- return bs
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
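+
+ # The residual sum is taken on the full cached tensors returned by self.f; the
+ # bracket (first, nb) is forwarded unchanged.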
##############################
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
- bs.x = self.cache_y
-
- return bs
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
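+
+ # As in CacheWrapper, only the bracketed positions are updated: the matching rows
+ # of the positional-encoding table are added to the slice and patched into
+ # self.cache_y.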
##############################
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
):
super().__init__()
self.causal = causal
self.attention_dropout = attention_dropout
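+ # When True, the attention maps computed in forward() are kept in self.a so that
+ # MyGPT.record_attention / retrieve_attention (below) can expose them.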
+ self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
self.w_k = randw(nb_heads, dim_qk, dim_in)
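+ # Per-head projection weights of shape (nb_heads, dim_qk, dim_in); queries and keys
+ # share dim_qk so that their dot products are defined.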
def forward(self, bs_q):
x_q = bs_q.x
+ assert (
+ self.causal or bs_q.complete()
+ ), "Partial evaluation is only possible for causal models"
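+ # A non-causal attention at position t needs keys/values from positions > t, so it
+ # can only be evaluated once the whole sequence is present; bracketed evaluation is
+ # therefore restricted to causal models.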
+
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
q = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
)
+
self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
"ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
)
)
a = a.softmax(dim=3)
+
+ if self.record_attention:
+ self.a = a
+
a = F.dropout(a, self.attention_dropout, self.training)
y = torch.einsum(
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
- bs_q.x = self.cache_y
+ return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+
- return bs_q
+##############################
+
+
+class NoiseInjector(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.noise_std = 0.0
+
+ def forward(self, x):
+ if self.noise_std > 0:
+ x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+ return x
+
+
+def set_noise_injection(model, noise_std):
+ for m in model.modules():
+ if isinstance(m, NoiseInjector):
+ m.noise_std = noise_std
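+
+
+# Usage sketch (train_one_epoch is a hypothetical helper, not defined in this file):
+# perturb the activations during training and switch the noise off for evaluation.
+#
+#   set_noise_injection(model, 0.1)   # every NoiseInjector in the trunk now adds noise
+#   train_one_epoch(model)            # hypothetical training step
+#   set_noise_injection(model, 0.0)   # deterministic model again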
##############################
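+ # Each block gets a NoiseInjector right after its LayerNorm, in both the attention
+ # and the MLP sub-blocks, so set_noise_injection() perturbs the whole trunk uniformly.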
for b in range(nb_blocks):
trunk_blocks += [
WithResidual(
- CacheWrapper(nn.LayerNorm((dim_model,))),
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
+ ),
QKVAttention(
dim_in=dim_model,
dim_qk=dim_keys,
WithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
nn.Linear(in_features=dim_model, out_features=dim_hidden),
nn.ReLU(),
nn.Linear(in_features=dim_hidden, out_features=dim_model),
m.weight.fill_(1.0)
def forward(self, bs):
- bs.x = F.pad(bs.x, (1, -1))
+ # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
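+ # Shift the tokens one position to the right (position 0 receives a 0), so that the
+ # output at position t is a prediction of token t from tokens 0..t-1.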
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
return bs
- # ar_mask is a tensor with 0s and 1s, of same shape as input, with
- # 1s where tokens should be generated. The others are kept
- # unchanged.
+ def record_attention(self, v=True):
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ m.record_attention = v
- def masked_inplace_autoregression(
- self, input, ar_mask, forbidden_tokens=None, deterministic_synthesis=False
- ):
- to_generate = (ar_mask.sum(0) > 0).nonzero()
- if to_generate.min() > 0:
- self(
- BracketedSequence(input, 0, to_generate.min())
- ) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = self(BracketedSequence(input, s, 1)).x
- logits = output[:, s]
- if forbidden_tokens is not None:
- logits = logits.masked_fill(forbidden_tokens, float("-inf"))
- if deterministic_synthesis:
- t_next = logits.argmax(1)
- else:
- dist = torch.distributions.categorical.Categorical(logits=logits)
- t_next = dist.sample()
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+ def retrieve_attention(self):
+ a = []
+ for m in self.modules():
+ if isinstance(m, QKVAttention):
+ a.append(m.a)
+ return a
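+
+ # Usage sketch: recording the attention maps of a full (non-bracketed) pass.
+ #
+ #   model.record_attention(True)
+ #   model(BracketedSequence(x))       # x: (N, T) LongTensor of token indices
+ #   maps = model.retrieve_attention() # one (N, nb_heads, T, T) tensor per block
+ #   model.record_attention(False)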
######################################################################
if __name__ == "__main__":
print("Basic check.")
- vocabulary_size = 10
- x = torch.randint(vocabulary_size, (9, 7))
+ vocabulary_size = 3
+ x = torch.randint(vocabulary_size, (1, 5))
model = MyGPT(
vocabulary_size=vocabulary_size,
- dim_model=18,
- dim_keys=50,
- dim_hidden=100,
+ dim_model=4,
+ dim_keys=2,
+ dim_hidden=2,
nb_heads=2,
- nb_blocks=1,
+ nb_blocks=2,
dropout=0.1,
+ causal=True,
)
model.eval()
-
y1 = model(BracketedSequence(x)).x
-
y2 = torch.randn_like(y1)
for s in range(x.size(1)):
z = model(BracketedSequence(x, s, 1))
- y2[:, s] = z.x[:, s]
+ y2[:, s] = z.slice()
print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
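+
+ # Sketch of an extra check: attention maps recorded during a full pass should come
+ # out as one (batch, nb_heads, T, T) tensor per block.
+ model.record_attention(True)
+ model(BracketedSequence(x))
+ a = model.retrieve_attention()
+ print(f"attention sizes {[tuple(t.size()) for t in a]}")
+ model.record_attention(False)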