X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=67c5cfd96ff11a5fd04b88eb6b26a72c90e97ddb;hb=6d23462ce76c9020dcd7c4bc8a0e7a0fae9b7971;hp=54515841259c63265ed143f7abf5e529db6763bc;hpb=de99e48d5c2dfb72e811f0bb1c2c09aa154af8b6;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index 5451584..67c5cfd 100755 --- a/mygpt.py +++ b/mygpt.py @@ -21,6 +21,8 @@ from torch.nn import functional as F import ffutils +# from blanket import blanket + # import memload ###################################################################### @@ -202,7 +204,7 @@ class DumbRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -333,7 +335,7 @@ class KVRec(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -487,44 +489,27 @@ class Caterpillar(nn.Module): attention_dropout=0.0, len_max=1e5, logger=print, - **kwargs, + args=None, ): super().__init__() warnings.warn("Caterpillar", RuntimeWarning) - def randw(*d, amplitude=None): - if amplitude is None: - amplitude = 1 / math.sqrt(d[-1]) - return nn.Parameter(amplitude * torch.randn(*d)) + def randw(*d, factor=1): + return nn.Parameter(torch.randn(*d) * factor / math.sqrt(d[-1])) self.caterpillar_length = caterpillar_length self.caterpillar_height = caterpillar_height self.attention_dropout = attention_dropout - ###################################################################### - # sup_args - - x = kwargs.get("gate_dropout") - if x is None: - self.proba_gate_dropout = 0.0 - else: - self.proba_gate_dropout = float(x) - - logger(f"self.proba_gate_dropout {self.proba_gate_dropout}") - - x = kwargs.get("default_bg") - if x is None: - default_bg = -math.log(caterpillar_height - 1) - else: - default_bg = float(x) - - logger(f"default_bg {default_bg}") + self.gate_dropout_proba = args.gate_dropout_proba + self.gate_dropout_sync = args.gate_dropout_sync + self.gate_dropout_replace = args.gate_dropout_replace ###################################################################### - self.w_G = randw(nb_heads, caterpillar_height, dim_model) - self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_bg)) + self.w_G = randw(nb_heads, caterpillar_height, dim_model, factor=1e-3) + self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), 0.0)) self.w_K = randw(nb_heads, dim_qk, dim_model) self.w_V = randw(nb_heads, dim_v, dim_model) @@ -542,14 +527,14 @@ class Caterpillar(nn.Module): dim_v, ) - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 + # def reset_inner_loss(self): + # self.acc_attention = 0 + # self.acc_nb = 0 - def get_inner_loss(self): - # warnings.warn("l2 regularization", RuntimeWarning) - # return (self.acc_attention / self.acc_nb).pow(2).sum() - return torch.tensor([0], device=self.w_Q.device) + # def get_inner_loss(self): + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + # return torch.tensor([0], device=self.w_Q.device) def forward(self, bs): # Dimensions to make the source a bit clearer, that's needed @@ -584,6 +569,8 @@ class Caterpillar(nn.Module): V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) + # V, K = blanket(V), blanket(K) + ###################################################################### # Compute the recurrent state @@ -602,12 +589,14 @@ class Caterpillar(nn.Module): G = G / G.sum(1, keepdim=True).clamp(min=1) + # G_star = (1 - G).log().sum(1, keepdim=True).exp() + ###################################################################### def recurrence(G, V, K): # We prepare the arguments for the parallel scan - A = 1 - G.sum(1) + A = 1 - G.sum(dim=1) gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) @@ -627,11 +616,8 @@ class Caterpillar(nn.Module): gated_V = gated_V.unflatten(2, (-1, L)) gated_K = gated_K.unflatten(2, (-1, L)) - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) - - next_V = next_V.flatten(2, 3) - next_K = next_K.flatten(2, 3) + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2).flatten(2, 3) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2).flatten(2, 3) return next_V, next_K @@ -639,31 +625,44 @@ class Caterpillar(nn.Module): next_V, next_K = recurrence(G, V, K) - if self.training and self.proba_gate_dropout > 0.0: + if self.training and self.gate_dropout_proba > 0.0: # G is NxHxRxT where r is the caterpillar's row. warnings.warn("gate dropout", RuntimeWarning) - # kill = ( - # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout - # ).float() + if self.gate_dropout_sync: + shape_kill = (N, 1, 1) + else: + shape_kill = (N, H, R) + # Pick a point in each of the NxHxR timeline and set this + # entry and the following to 1 kill = ( - torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0 + torch.rand(*shape_kill, t1 - t0, device=G.device).sort(dim=3).indices + == 0 ).cumsum(dim=3) + + # Keep these mask for only some of the NxHxR kill = kill * ( - torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout + torch.rand(*shape_kill, 1, device=G.device) <= self.gate_dropout_proba ) + # The coefficient to keep are the complementary mask = 1 - kill masked_next_V, masked_next_K = recurrence(G * mask, V, K) - next_V = next_V.detach() + (masked_next_V - masked_next_V.detach()) / ( - 1 - self.proba_gate_dropout + if self.gate_dropout_replace: + next_V = next_V.detach() + next_K = next_K.detach() + + warnings.warn("the rescaling is probably a bad idea", RuntimeWarning) + + next_V = next_V + (masked_next_V - masked_next_V.detach()) / ( + 1 - self.gate_dropout_proba ) - next_K = next_K.detach() + (masked_next_K - masked_next_K.detach()) / ( - 1 - self.proba_gate_dropout + next_K = next_K + (masked_next_K - masked_next_K.detach()) / ( + 1 - self.gate_dropout_proba ) self.rec_V[:, :, t0:t1] = next_V @@ -674,8 +673,10 @@ class Caterpillar(nn.Module): Q = torch.einsum("ntc,hdc->nhtd", X, self.w_Q) - # We build tensors NxHxTxFxL where N is the sample index, H - # the head, T the time, F the row in the caterpillar, and L + # Q = blanket(Q) + + # We build tensors NxHxTxRxL where N is the sample index, H + # the head, T the time, R the row in the caterpillar, and L # the column in the caterpillar windowed_V = moving_window( @@ -689,7 +690,7 @@ class Caterpillar(nn.Module): # We have an attention score for each of the RxL values ar = torch.einsum( - "nhtd,nftld->nhtfl", + "nhtd,nrtld->nhtrl", Q, windowed_K, ) / math.sqrt(DK) @@ -711,6 +712,8 @@ class Caterpillar(nn.Module): # Compute the final output + # Y = blanket(Y) + self.cache_Y[:, t0:t1] = Y @ self.w_O return BracketedSequence(self.cache_Y, t0, t1 - t0, bs.init_cache) @@ -729,7 +732,7 @@ class QKVAttention(nn.Module): causal=False, attention_dropout=0.0, logger=print, - **kwargs, + args=None, ): super().__init__() @@ -820,9 +823,9 @@ class MyGPT(nn.Module): causal=False, dropout=0.0, len_max=1e5, - attention_layer="kvrec", + attention_layer="caterpillar", logger=print, - **kwargs, + args=None, ): super().__init__() @@ -860,7 +863,7 @@ class MyGPT(nn.Module): causal=causal, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "dumbrec": return DumbRec( @@ -871,7 +874,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "kvrec": return KVRec( @@ -882,7 +885,7 @@ class MyGPT(nn.Module): nb_lines=nb_lines, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) elif attention_layer == "caterpillar": return Caterpillar( @@ -894,7 +897,7 @@ class MyGPT(nn.Module): caterpillar_height=self.caterpillar_height, attention_dropout=dropout, logger=logger, - **kwargs, + args=args, ) else: raise ValueError(f"Unknown attention type {attention_layer}.") @@ -1025,7 +1028,115 @@ class MyGPT(nn.Module): ###################################################################### if __name__ == "__main__": - print("Basic check.") + import argparse + + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.collections as mc + + args = argparse.Namespace( + gate_dropout_proba=0.0, gate_dropout_sync=True, gate_dropout_replace=False + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dim_model, dim_keys, nb_heads = 512, 64, 1 + dropout = 0.1 + + caterpillar = Caterpillar( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + caterpillar_length=16, + caterpillar_height=32, + attention_dropout=dropout, + args=args, + ).to(device) + + qkv = QKVAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + causal=True, + attention_dropout=dropout, + args=args, + ).to(device) + + linear = CacheWrapper(nn.Linear(512, 512)).to(device) + + x = torch.randn(1, 256, dim_model) + + x = x.to(device) + x.requires_grad_() + + ###################################################################### + + fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(8) + + ax = fig.add_subplot(1, 1, 1) + + # ax.set_xlim(-1.5, 1.5) + # ax.set_ylim(-1.5, 1.5) + # ax.set(aspect=1) + # ax.spines.right.set_visible(False) + # ax.spines.top.set_visible(False) + + # dt = 0.01 + # t = np.arange(dt, 20.0, dt) + # ax.semilogx(t, np.exp(-t / 5.0)) + # ax.grid() + ax.set_yscale("log") + + ###################################################################### + + for label, model, thickness in [ + ("nn.Linear", linear, 0.2), + ("mygpy.QKVAttention", qkv, 1), + ("mygpt.Caterpillar", caterpillar, 2), + ]: + y = model(BracketedSequence(x, 32, x.size(1) - 32, init_cache=True)).x + + for n, p in [("input", x)] + list(model.named_parameters()): + print(f"Processing {model}.{n}") + data = [] + for t in range(y.size(1)): + sg = 0 + for d in torch.randperm(y.size(2))[:8]: + sg += torch.autograd.grad(y[0, t, d], p, retain_graph=True)[0] + assert not sg.isinf().any() + assert not sg.isnan().any() + data.append([t, sg.sum().item()]) + + data = torch.tensor(data) + # cx, cy = data[:, 0], data[:, 1] + cy = data[:, 1].sort().values + cx = torch.linspace(0, 1, cy.size(0)) + ax.plot( + cx, cy, label=label + "." + n, linewidth=thickness + ) # , color='gray', label='Input') + + # ax.legend(frameon=False, loc="top right") + + # Put a legend to the right of the current axis + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + + filename = "plot.pdf" + print(f"saving {filename}") + fig.savefig(filename, bbox_inches="tight") + + # if args.window and hasattr(plt.get_current_fig_manager(), 'window'): + # plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) + # plt.show() + + exit(0) + + ###################################################################### m = Caterpillar( dim_model=4, @@ -1047,8 +1158,6 @@ if __name__ == "__main__": print((y1 - torch.cat([y3a, y3b], dim=1)).abs().max()) exit(0) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - vocabulary_size = 128 x = torch.randint(vocabulary_size, (6, 1024))