######################################################################
+
+class BSQ(nn.Module):
+ def __init__(self, L):
+ super().__init__()
+ self.L = L
+
+ def forward(self, input, indexes=False):
+ norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
+ u = input / norm
+
+ if indexes:
+ return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1)
+
+ hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+ if self.training:
+ self.loss += u.mean(dim=0).tanh().pow(2).mean()
+ return hat_u + u - u.detach()
+ else:
+ return hat_u
+
+
+class RandomBypass(nn.Module):
+ def __init__(self, m, p):
+ super().__init__()
+ self.m = m
+ self.p = p
+
+ def forward(self, x):
+ y = self.m(x)
+
+ if self.training:
+ u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+ return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+ else:
+ return y
+
+
+######################################################################
+
# A BracketedSequence is a BxTx... tensor with a first and a nb time
# steps to compute.
##############################
+class EncoderHead(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.fc = nn.Linear(dim_in, dim_out)
+
+ def forward(self, bs):
+ z = self.fc(bs.x).mean(dim=1)
+ return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.fc = nn.Linear(dim_in, dim_out)
+
+ def forward(self, z_shape):
+ z, shape = z_shape
+ y = self.fc(z)[:, None, :].expand(shape)
+ return BracketedSequence(y)
+
+
+##############################
+
+
class QKVAttention(nn.Module):
def __init__(
self,
dim_qk,
dim_v,
nb_heads=1,
- causal=False,
+ compute_attzero=None,
attention_dropout=0.0,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
- self.causal = causal
+ self.compute_attzero = compute_attzero
self.attention_dropout = attention_dropout
self.record_attention = False
def forward(self, bs_q):
x_q = bs_q.x
- assert (
- self.causal or bs_q.complete()
- ), "Partial evaluation is only possible for causal models"
-
if bs_q.first == 0:
self.cache_k = x_q.new_zeros(
x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
) / math.sqrt(self.w_q.size(1))
- if self.causal:
+ if self.compute_attzero is not None:
if bs_q.first == 0:
- self.cache_attzero = (
- torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
- < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
- )
+ self.cache_attzero = self.compute_attzero(
+ torch.arange(x_q.size(1), device=q.device)[:, None],
+ torch.arange(x_q.size(1), device=q.device)[None, :],
+ )[None, None, :, :]
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
##############################
+class NoiseInjector(nn.Module):
+ def __init__(self, identifier=None):
+ super().__init__()
+ self.noise_std = 0.0
+ self.identifier = identifier
+
+ def forward(self, x):
+ if self.noise_std > 0:
+ x = x * (
+ 1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+ )
+ return x
+
+
+##############################
+
+
class MyGPT(nn.Module):
def __init__(
self,
dim_hidden,
nb_heads,
nb_blocks,
- causal=False,
+ compute_attzero=None,
+ autoencoder_dim=-1,
dropout=0.0,
len_max=1e5,
):
assert dim_model % nb_heads == 0
+ self.temperature = 1.0
+
self.embedding = nn.Sequential(
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
AddPositionalEncoding(len_max),
for b in range(nb_blocks):
trunk_blocks += [
WithResidual(
- CacheWrapper(nn.LayerNorm((dim_model,))),
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ NoiseInjector(identifier=("attention", b)),
+ ),
QKVAttention(
dim_in=dim_model,
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
- causal=causal,
+ compute_attzero=compute_attzero,
attention_dropout=dropout,
),
),
WithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
+ NoiseInjector(identifier=("ffw", b)),
nn.Linear(in_features=dim_model, out_features=dim_hidden),
nn.ReLU(),
nn.Linear(in_features=dim_hidden, out_features=dim_model),
nn.Linear(in_features=dim_model, out_features=vocabulary_size)
)
+ # -------------------------------------------------------
+ if autoencoder_dim > 0:
+ self.encoder = nn.Sequential(
+ *(
+ trunk_blocks[: nb_blocks // 2]
+ + [EncoderHead(dim_model, autoencoder_dim)]
+ )
+ )
+
+ self.decoder = nn.Sequential(
+ *(
+ [
+ DecoderBottom(autoencoder_dim, dim_model),
+ AddPositionalEncoding(len_max),
+ ]
+ + trunk_blocks[nb_blocks // 2 :]
+ )
+ )
+ # -------------------------------------------------------
+
with torch.no_grad():
for m in self.modules():
if isinstance(m, nn.Embedding):
m.weight.fill_(1.0)
def forward(self, bs):
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+ for m in self.modules():
+ m.loss = 0
+
bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
+ bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+ for m in self.modules():
+ self.loss += m.loss
+
return bs
- # ar_mask is a tensor with 0s and 1s, of same shape as input, with
- # 1s where tokens should be generated. The others are kept
- # unchanged.
+ def encode(self, bs):
+ bs = self.embedding(bs)
+ z = self.encoder(bs)
+ return z
- def masked_inplace_autoregression(
- self,
- input,
- ar_mask,
- temperature=1.0,
- deterministic_synthesis=False,
- forbidden_tokens=None,
- forced_biases=None,
- ):
- sum_logits = 0
- to_generate = (ar_mask.sum(0) > 0).nonzero()
- if to_generate.min() > 0:
- self(
- BracketedSequence(input, 0, to_generate.min())
- ) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = self(BracketedSequence(input, s, 1)).x
- logits = output[:, s]
- if forbidden_tokens is not None:
- logits = logits.masked_fill(forbidden_tokens, float("-inf"))
- if forced_biases is not None:
- logits = logits + forced_biases[None, :]
- if deterministic_synthesis:
- t_next = logits.argmax(1)
+ def decode(self, z_shape):
+ bs = self.decoder(z_shape)
+ bs = self.readout(bs)
+ return bs
+
+ def partial_forward(self, bs, start_layer=None, end_layer=None):
+ if start_layer is None:
+ # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+ bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ bs = self.embedding(bs)
+ if end_layer is not None:
+ return self.trunk[:end_layer](bs)
else:
- dist = torch.distributions.categorical.Categorical(logits=logits)
- t_next = dist.sample()
- sum_logits += logits.log_softmax(dim=-1)[
- torch.arange(t_next.size(0)), t_next
- ].sum()
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+ return bs
+ else:
+ bs = self.trunk[start_layer:](bs)
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+ return bs
- return sum_logits
+ def reset_transformations(self):
+ self.temperature = 1.0
+ for m in self.modules():
+ if isinstance(m, NoiseInjector):
+ m.noise_std = 0.0
+
+ def set_noise_injection(self, noise_std, identifier=None):
+ for m in self.modules():
+ if isinstance(m, NoiseInjector):
+ if identifier is None or identifier == m.identifier:
+ m.noise_std = noise_std
def record_attention(self, v=True):
for m in self.modules():
nb_heads=2,
nb_blocks=2,
dropout=0.1,
- causal=True,
)
model.eval()