import ffutils
+from blanket import blanket
+
# import memload
######################################################################
######################################################################
- self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
+ self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
self.w_K = randw(nb_heads, dim_qk, dim_model)
- self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
+ self.w_V = randw(nb_heads, dim_v, dim_model)
self.w_Q = randw(nb_heads, dim_qk, dim_model)
self.w_O = randw(dim_v * nb_heads, dim_model)
V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
+ V, K = blanket(V), blanket(K)
+
######################################################################
# Compute the recurrent state
# Clip the gating to avoid values greater than 1 when several
# heads hit the same row
- # G = G / G.sum(1, keepdim=True).clamp(min=1)
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
- H = (1 - G).log().sum(1, keepdim=True).exp()
+ # G_star = (1 - G).log().sum(1, keepdim=True).exp()
######################################################################
def recurrence(G, V, K):
# We prepare the arguments for the parallel scan
- A = H
+ A = 1 - G.sum(dim=1)
- gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V)
- gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K)
+ gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+ gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
# We start from cached values, which matters in inference
Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
+ Q = blanket(Q)
+
# We build tensors NxHxTxRxL where N is the sample index, H
# the head, T the time, R the row in the caterpillar, and L
# the column in the caterpillar
# Compute the final output
+ Y = blanket(Y)
+
self.cache_Y[:, t0:t1] = Y @ self.w_O
return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
# t = np.arange(dt, 20.0, dt)
# ax.semilogx(t, np.exp(-t / 5.0))
# ax.grid()
+ ax.set_yscale("log")
######################################################################
- for label, model in [
- # ("nn.Linear", linear),
- ("mygpy.QKVAttention", qkv),
- ("mygpt.Caterpillar", caterpillar),
+ for label, model, thickness in [
+ ("nn.Linear", linear, 0.2),
+ ("mygpy.QKVAttention", qkv, 1),
+ ("mygpt.Caterpillar", caterpillar, 2),
]:
y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
- data = []
- for t in range(y.size(1)):
- for d in torch.randperm(y.size(2))[:8]:
- g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
- sg = g.pow(2).sum().item()
- # sg = 0
- # for p in model.parameters():
- # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
- # sg = sg + g.pow(2).sum().item()
- data.append([t, sg])
-
- data = torch.tensor(data)
- ax.scatter(
- data[:, 0], data[:, 1], s=1, label=label
- ) # , color='gray', label='Input')
+ for n, p in [("input", x)] + list(model.named_parameters()):
+ print(f"Processing {model}.{n}")
+ data = []
+ for t in range(y.size(1)):
+ sg = 0
+ for d in torch.randperm(y.size(2))[:8]:
+ sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
+ assert not sg.isinf().any()
+ assert not sg.isnan().any()
+ data.append([t, sg.sum().item()])
+
+ data = torch.tensor(data)
+ # cx, cy = data[:, 0], data[:, 1]
+ cy = data[:, 1].sort().values
+ cx = torch.linspace(0, 1, cy.size(0))
+ ax.plot(
+ cx, cy, label=label + "." + n, linewidth=thickness
+ ) # , color='gray', label='Input')
# ax.legend(frameon=False, loc="top right")