X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=040845ede9e4307f6e76c8a4a7faadd5bacd9974;hb=3d7db5b3c1304fdbd599c2a001b5c31df4df2599;hp=4d4824707baf1f8d22d961f8331cd8fe11cb510c;hpb=c0750e416e28fbdc9f6dc03cc6d7b11edd1ac333;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 4d48247..040845e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -10,6 +10,8 @@ # with a caching mechanism for keys and values to avoid a O(N^3) cost # for auto-regression. +# This implementation is equipped with RNN layers to replace the MHA + import math, warnings import torch, einops @@ -37,7 +39,7 @@ import ffutils # 1 for the successive tokens. # # Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 +# resetted when init_cache is True class BracketedSequence: @@ -124,7 +126,6 @@ class AddPositionalEncoding(nn.Module): import pscan - # X is /.../xTxD A is /.../xT Y_init is /.../xD @@ -145,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2): return Y +def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2): + with torch.no_grad(): + s_A, s_X = 0, 0 + for t in range(X.size(dim) - 1, 0, -1): + delta = (grad_Y[t] - s_A) / A[t].grad + s_A += A[t].grad * delta + A[t].grad = delta + delta = (grad_Y[t] - s_X) / X[t].grad + s_X += X[t].grad * delta + X[t].grad = delta + + def pscan_shape(A, X, Y_init): s = X.size() A = A.reshape(-1, s[-2]) @@ -188,6 +201,8 @@ class DumbRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + args=None, ): super().__init__() @@ -317,6 +332,8 @@ class KVRec(nn.Module): nb_lines, attention_dropout=0.0, len_max=1e5, + logger=print, + args=None, ): super().__init__() @@ -469,41 +486,53 @@ class Caterpillar(nn.Module): caterpillar_height, attention_dropout=0.0, len_max=1e5, + logger=print, + args=None, ): super().__init__() warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + def randw(*d, factor=1): + return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1])) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter( - torch.full( - (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1) - ) - ) + self.gate_dropout_proba = args.gate_dropout_proba + self.gate_dropout_sync = args.gate_dropout_sync + self.gate_dropout_replace = args.gate_dropout_replace + + ###################################################################### + + self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1.0) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0)) self.w_K = randw(nb_heads, dim_qk, dim_model) - self.w_V = randw(nb_heads, dim_v, dim_model) + self.w_V = randw(nb_heads, dim_v, dim_model, factor=1) self.w_Q = randw(nb_heads, dim_qk, dim_model) self.w_O = randw(dim_v * nb_heads, dim_model) - self.init_K_rec = randw(caterpillar_height, caterpillar_length, dim_qk) - self.init_V_rec = randw(caterpillar_height, caterpillar_length, dim_v) + self.init_K_rec = randw( + caterpillar_height, + caterpillar_length, + dim_qk, + ) + self.init_V_rec = randw( + caterpillar_height, + caterpillar_length, + dim_v, + ) - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 + # def reset_inner_loss(self): + # self.acc_attention = 0 + # self.acc_nb = 0 - def get_inner_loss(self): - # warnings.warn("l2 regularization", RuntimeWarning) - # return (self.acc_attention / self.acc_nb).pow(2).sum() - return torch.tensor([0], device=self.w_Q.device) + # def get_inner_loss(self): + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_Q.device) def forward(self, bs): # Dimensions to make the source a bit clearer, that's needed @@ -512,92 +541,150 @@ class Caterpillar(nn.Module): N = bs.x.size(0) T = bs.x.size(1) + H = self.w_V.size(0) DV = self.w_V.size(1) DK = self.w_K.size(1) - Dout = self.w_O.size(1) - CH = self.caterpillar_height - CL = self.caterpillar_length + DM = self.w_O.size(1) + R = self.caterpillar_height + L = self.caterpillar_length assert ( - t0 >= CL and (t1 - t0) % CL == 0 + t0 >= L and (t1 - t0) % L == 0 ), f"bs.first should be greater than caterpillar_length, and bs.nb should be a multiple of caterpillar_length" + # We cache values to deal efficiently with auto-regression + if bs.init_cache: - self.rec_V = X.new_zeros(N, CH, T, DV) - self.rec_K = X.new_zeros(N, CH, T, DK) + self.rec_V = X.new_zeros(N, R, T, DV) + self.rec_K = X.new_zeros(N, R, T, DK) # We start the recurrent sequences with optimizable # initial values. No idea if it helps. - self.rec_V[:, :, t0 - CL : t0] = self.init_V_rec[None, :, :, :] - self.rec_K[:, :, t0 - CL : t0] = self.init_K_rec[None, :, :, :] + self.rec_V[:, :, t0 - L : t0, :] = self.init_V_rec[None, :, :, :] + self.rec_K[:, :, t0 - L : t0, :] = self.init_K_rec[None, :, :, :] - self.cache_Y = X.new_zeros(N, T, Dout) + self.cache_Y = X.new_zeros(N, T, DM) + + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) + K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) ###################################################################### # Compute the recurrent state # This is the Gating sequence that modulates the storing of - # the new key and value in the CH pairs of the current - # stack. The CH gating values are independent, which means - # that the current K/V could be stored in all the pairs of the + # the new key and value in the R pairs of the current + # stack. There are R independent gating values, which means + # that the current K/V may be stored in multiple pairs of the # recurrent state, or not at all. G = ( - torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] + torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - G = F.dropout(G, self.attention_dropout, self.training) + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) - K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + # G = G / G.sum(1, keepdim=True).clamp(min=1) + + H = (1 - G).log().sum(1, keepdim=True).exp() + + ###################################################################### + + def recurrence(G, V, K): + # We prepare the arguments for the parallel scan + + A = H + + gated_V = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", H * G / (1 - G), K) + + # We start from cached values, which matters in inference + + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - # We prepare the arguments for the parallel scan + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. - A = 1 - G.sum(1) - gated_V = torch.einsum("nhet,nhtd->netd", G, V) - gated_K = torch.einsum("nhet,nhtd->netd", G, K) + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) - init_rec_V = self.rec_V[:, :, t0 - CL : t0] - init_rec_K = self.rec_K[:, :, t0 - CL : t0] + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) - # Here there is a trick: Since the stack at time t is computed - # by updating that at time t-L, the parallel scan operates - # with a period of L. To do so we split the time indexing in - # two axes, the second of size CL, and run the parallel scan - # using the other as the sequence index. + return next_V, next_K - A = A.unflatten(2, (-1, CL)) - gated_V = gated_V.unflatten(2, (-1, CL)) - gated_K = gated_K.unflatten(2, (-1, CL)) + ################################################################# - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + next_V, next_K = recurrence(G, V, K) - # Put back the sequence index + if self.training and self.gate_dropout_proba > 0.0: + # G is NxHxRxT where r is the caterpillar's row. - self.rec_V[:, :, t0:t1] = next_V.flatten(2, 3) - self.rec_K[:, :, t0:t1] = next_K.flatten(2, 3) + warnings.warn("gate dropout", RuntimeWarning) + + if self.gate_dropout_sync: + shape_kill = (N, 1, 1) + else: + shape_kill = (N, H, R) + + # Pick a point in each of the NxHxR timeline and set this + # entry and the following to 1 + kill = ( + torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices + == 0 + ).cumsum(dim=3) + + # Keep these mask for only some of the NxHxR + kill = kill * ( + torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba + ) + + # The coefficient to keep are the complementary + mask = 1 - kill + + masked_next_V, masked_next_K = recurrence(G * mask, V, K) + + if self.gate_dropout_replace: + next_V = next_V.detach() + next_K = next_K.detach() + + warnings.warn("the rescaling is probably a bad idea", RuntimeWarning) + + next_V = next_V + (masked_next_V - masked_next_V.detach()) / ( + 1 - self.gate_dropout_proba + ) + next_K = next_K + (masked_next_K - masked_next_K.detach()) / ( + 1 - self.gate_dropout_proba + ) + + self.rec_V[:, :, t0:t1] = next_V + self.rec_K[:, :, t0:t1] = next_K ###################################################################### # compute the readout Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - # We build tensors NxHxTxFxL where N is the sample index, H - # the head, T the time, F the row in the caterpillar, and L + # We build tensors NxHxTxRxL where N is the sample index, H + # the head, T the time, R the row in the caterpillar, and L # the column in the caterpillar windowed_V = moving_window( - self.rec_V[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_V[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) windowed_K = moving_window( - self.rec_K[:, :, t0 - CL + 1 : t1], dim=2, win_dim=3, win_size=CL + self.rec_K[:, :, t0 - L + 1 : t1], dim=2, win_dim=3, win_size=L ) - # We have an attention score for each of the CHxCL values + # We have an attention score for each of the RxL values ar = torch.einsum( - "nhtd,nftld->nhtfl", + "nhtd,nrtld->nhtrl", Q, windowed_K, ) / math.sqrt(DK) @@ -636,6 +723,8 @@ class QKVAttention(nn.Module): nb_heads=1, causal=False, attention_dropout=0.0, + logger=print, + args=None, ): super().__init__() @@ -723,15 +812,21 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, - attention_layer="kvrec", + attention_layer="caterpillar", + logger=print, + args=None, ): super().__init__() - assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"} + assert attention_layer in { + "mha", + "dumbrec", + "kvrec", + "caterpillar", + }, f"Unknown attention operator {attention_layer}." if attention_layer == "caterpillar": assert nb_lines % caterpillar_height == 0 @@ -759,34 +854,42 @@ class MyGPT(nn.Module): nb_heads=nb_heads, causal=causal, attention_dropout=dropout, + logger=logger, + args=args, ) elif attention_layer == "dumbrec": return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + args=args, ) elif attention_layer == "kvrec": return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, + logger=logger, + args=args, ) elif attention_layer == "caterpillar": return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height, attention_dropout=dropout, + logger=logger, + args=args, ) else: raise ValueError(f"Unknown attention type {attention_layer}.") @@ -917,7 +1020,111 @@ class MyGPT(nn.Module): ###################################################################### if __name__ == "__main__": - print("Basic check.") + import argparse + + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.collections as mc + + args = argparse.Namespace( + gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dim_model, dim_keys, nb_heads = 512, 64, 1 + dropout = 0.1 + + caterpillar = Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=16, + caterpillar_height=32, + attention_dropout=dropout, + args=args, + ).to(device) + + qkv = QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=True, + attention_dropout=dropout, + args=args, + ).to(device) + + linear = CacheWrapper(nn.Linear(512, 512)).to(device) + + x = torch.randn(1, 256, dim_model) + + x = x.to(device) + x.requires_grad_() + + ###################################################################### + + fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(8) + + ax = fig.add_subplot(1, 1, 1) + + # ax.set_xlim(-1.5, 1.5) + # ax.set_ylim(-1.5, 1.5) + # ax.set(aspect=1) + # ax.spines.right.set_visible(False) + # ax.spines.top.set_visible(False) + + # dt = 0.01 + # t = np.arange(dt, 20.0, dt) + # ax.semilogx(t, np.exp(-t / 5.0)) + # ax.grid() + + ###################################################################### + + for label, model in [ + # ("nn.Linear", linear), + ("mygpy.QKVAttention", qkv), + ("mygpt.Caterpillar", caterpillar), + ]: + y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x + + data = [] + for t in range(y.size(1)): + for d in torch.randperm(y.size(2))[:8]: + g = torch.autograd.grad(y[0, t, d], x, retain_graph=True)[0] + sg = g.pow(2).sum().item() + # sg = 0 + # for p in model.parameters(): + # g = torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0] + # sg = sg + g.pow(2).sum().item() + data.append([t, sg]) + + data = torch.tensor(data) + ax.scatter( + data[:, 0], data[:, 1], s=1, label=label + ) # , color='gray', label='Input') + + # ax.legend(frameon=False, loc="top right") + + # Put a legend to the right of the current axis + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + filename = "plot.pdf" + print(f"saving {filename}") + fig.savefig(filename, bbox_inches="tight") + + # if args.window and hasattr(plt.get_current_fig_manager(), 'window'): + # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + # plt.show() + + exit(0) + + ###################################################################### m = Caterpillar( dim_model=4, @@ -939,8 +1146,6 @@ if __name__ == "__main__": print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max()) exit(0) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - vocabulary_size = 128 x = torch.randint(vocabulary_size, (6, 1024))