X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=fridge;h=a4d860b73ac2f720363299030a75611f33110871;hb=HEAD;hp=82d2b17a9d3235917f7047d5c71f4f759a027fba;hpb=3d7db5b3c1304fdbd599c2a001b5c31df4df2599;p=mygptrnn.git diff --git a/fridge b/fridge index 82d2b17..a4d860b 100644 --- a/fridge +++ b/fridge @@ -316,3 +316,92 @@ class Calibrator: if isinstance(m, mygpt.Caterpillar): + +###################################################################### + +2024 Feb 13 22:53:52 (from mygpt.py) + + ###################################################################### + # Prepare the keys + + k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) + + warnings.warn("rotating key barrel", RuntimeWarning) + k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) + t_barrel = torch.arange(t0, t1, device=k_star.device) + t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) + l_barrel = ( + torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + ) % k_star.size(0) + k_star = k_star[l_barrel, t_barrel] + + +###################################################################### + +2024 Feb 15 23:10:50 (from main.py) + + +def add_memex_v4(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u = torch.rand(t.size(), device=t.device) + u[:, : input.size(1)] = 1.0 + memex_v3_proba_fragment = 1 / 20 + u = (u < memex_v3_proba_fragment).long() + v = u * torch.randint(input.size(1), u.size()) + u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ + :, : input.size(1) - 1 + ] * input.size(1) + u = u.cumsum().clamp(min=0) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + + +###################################################################### + +2024 Feb 16 17:07:48 (from main.py) + + # ||gn + lambda * gm|| = max(||gn||,||gm||) + # ||gn||^2 + lambda + lambda^2||gm||^2 = max(||gn||^2,||gm||^2) + # A = ||gm||^2 B = C = ||gn||^2 - max(||gn||^2, ||gm||^2) + +###################################################################### + +2024 Feb 16 17:07:51 (from main.py) + + # A,B,C = gmgm, gngm, gngn - max(gngn,gmgm) + # Delta = B*B - 4*A*C + # if(delta >= 0): + # l = ( -B - sqrt(Delta))/(2*A) + # ||gn||+l*rho*||gm|| = max(||gn||,rho*||gm||)