From 318946b9a800dcc07531053e345bda46440f617f Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Tue, 23 Jan 2024 22:41:17 +0100
Subject: [PATCH] Update.

---
 blanket.py | 44 +++++++++++++++++++++++++++++++++++++
 mygpt.py   | 64 ++++++++++++++++++++++++++++++++----------------------
 2 files changed, 82 insertions(+), 26 deletions(-)
 create mode 100755 blanket.py

diff --git a/blanket.py b/blanket.py
new file mode 100755
index 0000000..2b9c896
--- /dev/null
+++ b/blanket.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Blanket(torch.autograd.Function):
+    @staticmethod
+    def normalize(x):
+        y = x.flatten(1)
+        y /= y.pow(2).sum(dim=1, keepdim=True).sqrt() + 1e-6
+        y *= math.sqrt(y.numel() / y.size(0))
+
+    @staticmethod
+    def forward(ctx, x):
+        x = x.clone()
+        # Normalize the forward
+        Blanket.normalize(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.clone()
+        # Normalize the gradient
+        Blanket.normalize(grad_output)
+        return grad_output
+
+
+blanket = Blanket.apply
+
+######################################################################
+
+if __name__ == "__main__":
+    x = torch.rand(2, 3).requires_grad_()
+    y = blanket(x) * 10
+    print(y.pow(2).sum())
+    z = y.sin().sum()
+    g = torch.autograd.grad(z, x)[0]
+
+    print(g.pow(2).sum())
diff --git a/mygpt.py b/mygpt.py
index 040845e..9a02bcd 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -21,6 +21,8 @@ from torch.nn import functional as F
 
 import ffutils
 
+from blanket import blanket
+
 # import memload
 
 ######################################################################
@@ -506,11 +508,11 @@ class Caterpillar(nn.Module):
 
         ######################################################################
 
-        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0)
+        self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3)
         self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
-        self.w_V = randw(nb_heads, dim_v, dim_model, factor=1)
+        self.w_V = randw(nb_heads, dim_v, dim_model)
         self.w_Q = randw(nb_heads, dim_qk, dim_model)
         self.w_O = randw(dim_v * nb_heads, dim_model)
 
@@ -567,6 +569,8 @@ class Caterpillar(nn.Module):
         V = torch.einsum("ntc,hdc->nhtd", X, self.w_V)
         K = torch.einsum("ntc,hdc->nhtd", X, self.w_K)
 
+        V, K = blanket(V), blanket(K)
+
         ######################################################################
         # Compute the recurrent state
 
@@ -583,19 +587,19 @@ class Caterpillar(nn.Module):
 
         # Clip the gating to avoid values greater than 1 when several
         # heads hit the same row
 
-        # G = G / G.sum(1, keepdim=True).clamp(min=1)
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
 
-        H = (1 - G).log().sum(1, keepdim=True).exp()
+        # G_star = (1 - G).log().sum(1, keepdim=True).exp()
 
         ######################################################################
 
         def recurrence(G, V, K):
             # We prepare the arguments for the parallel scan
 
-            A = H
+            A = 1 - G.sum(dim=1)
 
-            gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V)
-            gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K)
+            gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)
+            gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K)
 
             # We start from cached values, which matters in inference
@@ -669,6 +673,8 @@ class Caterpillar(nn.Module):
 
         Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q)
 
+        Q = blanket(Q)
+
         # We build tensors NxHxTxRxL where N is the sample index, H
         # the head, T the time, R the row in the caterpillar, and L
         # the column in the caterpillar
@@ -706,6 +712,8 @@ class Caterpillar(nn.Module):
 
         # Compute the final output
 
+        Y = blanket(Y)
+
         self.cache_Y[:, t0:t1] = Y @ self.w_O
 
         return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache)
@@ -1081,31 +1089,35 @@ if __name__ == "__main__":
     # t = np.arange(dt, 20.0, dt)
     # ax.semilogx(t, np.exp(-t / 5.0))
     # ax.grid()
+    ax.set_yscale("log")
 
     ######################################################################
 
-    for label, model in [
-        # ("nn.Linear", linear),
-        ("mygpy.QKVAttention", qkv),
-        ("mygpt.Caterpillar", caterpillar),
+    for label, model, thickness in [
+        ("nn.Linear", linear, 0.2),
+        ("mygpy.QKVAttention", qkv, 1),
+        ("mygpt.Caterpillar", caterpillar, 2),
     ]:
         y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x
 
-        data = []
-        for t in range(y.size(1)):
-            for d in torch.randperm(y.size(2))[:8]:
-                g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0]
-                sg = g.pow(2).sum().item()
-                # sg = 0
-                # for p in model.parameters():
-                #     g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
-                #     sg = sg + g.pow(2).sum().item()
-                data.append([t, sg])
-
-        data = torch.tensor(data)
-        ax.scatter(
-            data[:, 0], data[:, 1], s=1, label=label
-        )  # , color='gray', label='Input')
+        for n, p in [("input", x)] + list(model.named_parameters()):
+            print(f"Processing {model}.{n}")
+            data = []
+            for t in range(y.size(1)):
+                sg = 0
+                for d in torch.randperm(y.size(2))[:8]:
+                    sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0]
+                assert not sg.isinf().any()
+                assert not sg.isnan().any()
+                data.append([t, sg.sum().item()])
+
+            data = torch.tensor(data)
+            # cx, cy = data[:, 0], data[:, 1]
+            cy = data[:, 1].sort().values
+            cx = torch.linspace(0, 1, cy.size(0))
+            ax.plot(
+                cx, cy, label=label + "." + n, linewidth=thickness
+            )  # , color='gray', label='Input')
 
     # ax.legend(frameon=False, loc="top right")
-- 
2.39.5
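For reference, what blanket() enforces: Blanket.normalize rescales each sample of the flattened tensor to unit RMS, i.e. to a squared norm equal to the number of elements per sample, and backward applies the same rescaling to the incoming gradient, so gradient magnitudes stay pinned regardless of what happens upstream. A minimal check of that behavior, importing the blanket.py added above (the shapes and the 1e6 loss scale are arbitrary, chosen only for illustration):

import torch

from blanket import blanket  # the file added by this patch

x = (torch.randn(4, 8, 16) * 123.0).requires_grad_()
y = blanket(x)

# Forward: each sample is rescaled to squared norm = elements per sample (8 * 16)
print(y.flatten(1).pow(2).sum(dim=1))  # ~128. for all four samples

# Backward: however the loss is scaled, the gradient reaching x is
# renormalized per sample in exactly the same way
g = torch.autograd.grad((y * 1e6).sum(), x)[0]
print(g.flatten(1).pow(2).sum(dim=1))  # ~128. again

Note that normalize() works in place through the flatten(1) view, which is why both forward and backward clone their argument before normalizing it.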
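The gating hunk swaps the product-form forget factor H = prod_h (1 - G_h) for an additive one: once G is renormalized so its per-(row, time) sum over heads is at most 1, the scan coefficient A = 1 - sum_h G_h is guaranteed to lie in [0, 1], and the gated contributions no longer need the H * G / (1 - G) correction. A small numeric sketch of that arithmetic, with a sequential loop standing in for the parallel scan used in mygpt.py (the dimension sizes are made up, and the recurrence starts from zeros where the real code starts from the cached state):

import torch

N, H, R, T, D = 2, 3, 4, 5, 6  # batch, heads, rows, time, value dim (arbitrary)

G = torch.rand(N, H, R, T)     # stand-in for the sigmoid gating of Caterpillar

# Renormalize so that, for every (sample, row, time), the total gate
# mass over heads is at most 1
G = G / G.sum(1, keepdim=True).clamp(min=1)

# Forget factor of the scan, in [0, 1] by construction (up to rounding)
A = 1 - G.sum(dim=1)
assert A.min() > -1e-6 and A.max() <= 1 + 1e-6

V = torch.randn(N, H, T, D)
gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V)

# Sequential reference for the recurrence that the patch evaluates
# with a parallel scan
rec_V = torch.zeros(N, R, D)
for t in range(T):
    rec_V = A[:, :, t, None] * rec_V + gated_V[:, :, t]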