######################################################################
-def baseline(X, V):
+def baseline1(X, V):
Y = X.new(X.size())
W = V.new(V.size())
+
for t in range(X.size(1)):
if t == 0:
Y[:, t] = X[:, t]
W[:, t] = V[:, t]
else:
- m = (V[:, t] >= W[:, t - 1] - 1).long()
- Y[:, t] = m * X[:, t] + (1 - m) * Y[:, t - 1]
- W[:, t] = m * V[:, t] + (1 - m) * (W[:, t - 1] - 1)
+ m = (W[:, t - 1] - 1 >= V[:, t]).long()
+ W[:, t] = m * (W[:, t - 1] - 1) + (1 - m) * V[:, t]
+ Y[:, t] = m * Y[:, t - 1] + (1 - m) * (
+ X[:, t] * (1 + dv) + Y[:, t - 1] * dv0
+ )
+
+ return Y, W
+
+
+######################################################################
+
+
+def hs(x):
+ return x.sigmoid() # (x >= 0).float() + (x - x.detach()) * (x < 0).float()
+
+
+def baseline(X, V):
+ for t in range(X.size(1)):
+ if t == 0:
+ Y = X[:, t]
+ W = V[:, t]
+ else:
+ m = (W - 1 - V[:, t]).sigmoid()
+ # m = hs(W - 1 - V[:, t])
+ W = m * (W - 1) + (1 - m) * V[:, t]
+ Y = m * Y + (1 - m) * X[:, t]
return Y, W
Vrf = Vr[:, :T].view(Vr.size(0), Vr.size(1) // 2, 2)
# [:, :, 0] < [:, :, 1]
- dx = Xf[:, :, 1] - Xf[:, :, 1].detach()
+ dv0 = (Vf[:, :, 0] - Vf[:, :, 0].detach())[:, :, None]
dv = (Vf[:, :, 1] - Vf[:, :, 1].detach())[:, :, None]
m = (Vf[:, :, 0] - s >= Vf[:, :, 1]).long()
Vv = m * (Vf[:, :, 0] - s) + (1 - m) * Vf[:, :, 1]
m = m[:, :, None]
- Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + dx)
+ Xx = m * Xf[:, :, 0] + (1 - m) * (Xf[:, :, 1] * (1 + dv) + Xf[:, :, 0] * dv0)
Xrf[:, :, 1], Vrf[:, :, 1] = pscan_diff(Xx, Vv, s * 2)
- Xr[:, 0] = X[:, 0]
- Vr[:, 0] = V[:, 0]
-
# [:, :-1, 1] < [:, 1:, 0]
- dx = Xf[:, 1:, 0] - Xf[:, 1:, 0].detach()
+ dv0 = (Vrf[:, :-1, 1] - Vrf[:, :-1, 1].detach())[:, :, None]
dv = (Vf[:, 1:, 0] - Vf[:, 1:, 0].detach())[:, :, None]
m = (Vrf[:, :-1, 1] - s >= Vf[:, 1:, 0]).long()
Vrf[:, 1:, 0] = m * (Vrf[:, :-1, 1] - s) + (1 - m) * Vf[:, 1:, 0]
m = m[:, :, None]
- Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (Xf[:, 1:, 0] * (1 + dv) + dx)
+ Xrf[:, 1:, 0] = m * Xrf[:, :-1, 1] + (1 - m) * (
+ Xf[:, 1:, 0] * (1 + dv) + Xrf[:, :-1, 1] * dv0
+ )
+
+ Xr[:, 0] = X[:, 0]
+ Vr[:, 0] = V[:, 0]
if T < X.size(1):
# [:, -2] < [:, -1]
- dx = X[:, -1] - X[:, -1].detach()
+ dx = X[:, -2] - X[:, -2].detach()
dv = (V[:, -1] - V[:, -1].detach())[:, None]
m = (V[:, -2] - s >= V[:, -1]).long()
- Vr[:, -1] = m * (V[:, -2] - s) + (1 - m) * V[:, -1]
+ Vr[:, -1] = m * (Vr[:, -2] - s) + (1 - m) * V[:, -1]
m = m[:, None]
- Xr[:, -1] = m * X[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
+ Xr[:, -1] = m * Xr[:, -2] + (1 - m) * (X[:, -1] * (1 + dv) + dx)
return Xr, Vr
if __name__ == "__main__":
N = 1
- T = 513
- D = 2
+ T = 64
+ D = 128
- X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
- V = torch.rand(N, T, dtype=torch.float64) * 10
+ torch.autograd.set_detect_anomaly(True)
- X0, V0 = baseline(X, V)
+ for k in range(0):
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ V = torch.rand(N, T, dtype=torch.float64)
- # print("########### X0 V0 ###########################################")
- # print(V0)
- # print(X0)
+ X0, V0 = baseline(X, V)
- X1, V1 = pscan_diff(X, V)
+ # print("########### X0 V0 ###########################################")
+ # print(V0)
+ # print(X0)
- # print("########### X V ############################################")
- # print(V)
- # print(X)
+ X1, V1 = pscan_diff(X, V)
- print("ERROR", ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item())
+ # print("########### X V ############################################")
+ # print(V)
+ # print(X)
- exit(0)
+ error = ((X0 - X1).abs().max() + (V0 - V1).abs().max()).item()
+ if error > 0:
+ print("ERROR", error)
+ print(X0)
+ print(X1)
+ exit(0)
+
+ # exit(0)
# s = X1.sum()
# print(torch.autograd.grad(s, X))
# f.write(f"{V1[0,t].item()}\n")
Y = torch.randn(1, 1, D)
- X = torch.randn(
- N, T, D
- ) # * 0.1 + (torch.rand(N,T,1).sort(dim=1).indices==0).float() * Y
- V = torch.rand(N, T).requires_grad_()
+ X = torch.randn(N, T, D) * 0.1
+
+ m = (torch.rand(N, T, 1).sort(dim=1).indices == 0).float()
+ X = (1 - m) * X + m * Y
+ V = torch.rand(N, T) # + 100* m.squeeze(dim=-1)
+ V = V.requires_grad_()
- optimizer = torch.optim.SGD([V], lr=1e-2)
+ optimizer = torch.optim.SGD([V], lr=1e-1)
for k in range(1000):
- X1, V1 = X.clone(), V.clone()
- pscan(X, V, X1, V1)
- # X1=X1*(1+V1-V1.detach())[:,:,None]
- loss = (X1[:, -1:] - Y).pow(2).mean()
+ X1, V1 = baseline(X, V)
+ loss = (X1 - Y).pow(2).mean()
print(k, loss.item())
optimizer.zero_grad()
loss.backward()
self.caterpillar_height = caterpillar_height
self.attention_dropout = attention_dropout
- self.gate_dropout_proba = args.gate_dropout_proba
- self.gate_dropout_sync = args.gate_dropout_sync
- self.gate_dropout_replace = args.gate_dropout_replace
-
######################################################################
- self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model)
self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
self.w_K = randw(nb_heads, dim_qk, dim_model)
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
- # V, K = blanket(V), blanket(K)
-
######################################################################
# Compute the recurrent state
G = G / G.sum(1, keepdim=True).clamp(min=1)
- # G_star = (1 - G).log().sum(1, keepdim=True).exp()
-
######################################################################
- def recurrence(G, V, K):
- # We prepare the arguments for the parallel scan
-
- A = 1 - G.sum(dim=1)
-
- gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
- gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
-
- # We start from cached values, which matters in inference
-
- init_rec_V = self.rec_V[:, :, t0 - L : t0]
- init_rec_K = self.rec_K[:, :, t0 - L : t0]
-
- # Here there is a trick: Since the stack at position t is
- # computed by updating that at position t-L, the parallel
- # scan operates with a period of L. To do so we split the
- # sequence indexing in two axes, the second of size L, and
- # run the parallel scan using the first as the sequence index.
-
- A = A.unflatten(2, (-1, L))
- gated_V = gated_V.unflatten(2, (-1, L))
- gated_K = gated_K.unflatten(2, (-1, L))
-
- next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
- next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
+ A = 1 - G.sum(dim=1)
- return next_V, next_K
+ gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+ gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
- #################################################################
+ # We start from cached values, which matters in inference
- next_V, next_K = recurrence(G, V, K)
+ init_rec_V = self.rec_V[:, :, t0 - L : t0]
+ init_rec_K = self.rec_K[:, :, t0 - L : t0]
- if self.training and self.gate_dropout_proba > 0.0:
- # G is NxHxRxT where r is the caterpillar's row.
+ # Here there is a trick: Since the stack at position t is
+ # computed by updating that at position t-L, the parallel
+ # scan operates with a period of L. To do so we split the
+ # sequence indexing in two axes, the second of size L, and
+ # run the parallel scan using the first as the sequence index.
- warnings.warn("gate dropout", RuntimeWarning)
+ A = A.unflatten(2, (-1, L))
+ gated_V = gated_V.unflatten(2, (-1, L))
+ gated_K = gated_K.unflatten(2, (-1, L))
- if self.gate_dropout_sync:
- shape_kill = (N, 1, 1)
- else:
- shape_kill = (N, H, R)
-
- # Pick a point in each of the NxHxR timeline and set this
- # entry and the following to 1
- kill = (
- torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices
- == 0
- ).cumsum(dim=3)
-
- # Keep these mask for only some of the NxHxR
- kill = kill * (
- torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba
- )
-
- # The coefficient to keep are the complementary
- mask = 1 - kill
-
- masked_next_V, masked_next_K = recurrence(G * mask, V, K)
-
- if self.gate_dropout_replace:
- next_V = next_V.detach()
- next_K = next_K.detach()
-
- warnings.warn("the rescaling is probably a bad idea", RuntimeWarning)
-
- next_V = next_V + (masked_next_V - masked_next_V.detach()) / (
- 1 - self.gate_dropout_proba
- )
- next_K = next_K + (masked_next_K - masked_next_K.detach()) / (
- 1 - self.gate_dropout_proba
- )
+ next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3)
+ next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3)
self.rec_V[:, :, t0:t1] = next_V
self.rec_K[:, :, t0:t1] = next_K
windowed_V,
).flatten(2)
- # Compute the final output
-
- # Y = blanket(Y)
-
self.cache_Y[:, t0:t1] = Y @ self.w_O
return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
dim_v,
nb_heads=1,
causal=False,
+ horizon=None,
attention_dropout=0.0,
logger=print,
args=None,
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
self.causal = causal
+ self.horizon = horizon
self.attention_dropout = attention_dropout
self.record_attention = False
torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
< torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
)
+
+ if self.horizon is not None:
+ self.cache_attzero = torch.logical_or(
+ self.cache_attzero,
+ torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
+ >= torch.arange(x_q.size(1), device=q.device)[
+ None, None, None, :
+ ]
+ + self.horizon,
+ )
+
a = a.masked_fill(
self.cache_attzero[
:, :, bs.first : bs.first + bs.nb, : bs.first + bs.nb
"dumbrec",
"kvrec",
"caterpillar",
+ "attcat",
}, f"Unknown attention operator {attention_layer}."
- if attention_layer == "caterpillar":
+ if attention_layer == "caterpillar" or attention_layer == "attcat":
assert nb_lines % caterpillar_height == 0
self.caterpillar_length = nb_lines // caterpillar_height
self.caterpillar_height = caterpillar_height
def attlayer():
if attention_layer == "mha":
- return QKVAttention(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- causal=causal,
- attention_dropout=dropout,
- logger=logger,
- args=args,
+ return WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ QKVAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ causal=causal,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
)
elif attention_layer == "dumbrec":
- return DumbRec(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- nb_lines=nb_lines,
- attention_dropout=dropout,
- logger=logger,
- args=args,
+ return WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ DumbRec(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ nb_lines=nb_lines,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
)
elif attention_layer == "kvrec":
- return KVRec(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- nb_lines=nb_lines,
- attention_dropout=dropout,
- logger=logger,
- args=args,
+ return WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ KVRec(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ nb_lines=nb_lines,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
)
elif attention_layer == "caterpillar":
- return Caterpillar(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- caterpillar_length=self.caterpillar_length,
- caterpillar_height=self.caterpillar_height,
- attention_dropout=dropout,
- logger=logger,
- args=args,
+ return WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ Caterpillar(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ caterpillar_length=self.caterpillar_length,
+ caterpillar_height=self.caterpillar_height,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
+ )
+ elif attention_layer == "attcat":
+ return nn.Sequential(
+ WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ QKVAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ causal=causal,
+ horizon=self.caterpillar_length,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
+ ),
+ WithResidual(
+ CacheWrapper(nn.LayerNorm((dim_model,))),
+ Caterpillar(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ caterpillar_length=self.caterpillar_length,
+ caterpillar_height=self.caterpillar_height,
+ attention_dropout=dropout,
+ logger=logger,
+ args=args,
+ ),
+ ),
)
else:
raise ValueError(f"Unknown attention type {attention_layer}.")
for b in range(nb_blocks):
trunk_blocks += [
- WithResidual(
- CacheWrapper(nn.LayerNorm((dim_model,))),
- attlayer(),
- ),
+ attlayer(),
WithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),