From 73acbc986f9c386c001117581c4fc72d2f36803a Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 13 Jan 2024 19:39:23 +0100
Subject: [PATCH] Update.

---
 fridge   | 11 +++++++++++
 main.py  | 26 +++++++++++++++++++++++---
 mygpt.py | 50 +++++++++++++++++++++++++++++++++++---------------
 3 files changed, 69 insertions(+), 18 deletions(-)

diff --git a/fridge b/fridge
index dcaac19..194c4e6 100644
--- a/fridge
+++ b/fridge
@@ -166,3 +166,14 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba):
             + (1 - mask) * self.rec_K[:, :, t0:t1]
         )
 
+
+######################################################################
+
+2024 Jan 13 13:38:31 (from mygpt.py)
+
+        g= F.sigmoid(self.b_G)
+        a=1-g
+
+        print(f"\n\nSANITY {a**T}\n")
+        exit(0)
+
diff --git a/main.py b/main.py
index 969b47f..c22ae57 100755
--- a/main.py
+++ b/main.py
@@ -202,9 +202,11 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default
 
 ######################################################################
 
-args = parser.parse_args()
+# args = parser.parse_args()
 
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
 
 if args.result_dir is None:
     args.result_dir = f"results_{args.task}_{args.model}"
@@ -432,6 +434,8 @@ except FileExistsError:
     print(f"result directory {args.result_dir} already exists")
     exit(1)
 
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
 if args.seed >= 0:
@@ -461,13 +465,16 @@ with os.popen("sha256sum *.py") as f:
         log_string(f"sha256sum {l.strip()}")
 
 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+os.system(f"tar --ignore-failed-read zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
 
 log_string(f"argv {' '.join(sys.argv)}")
 
 for n in vars(args):
     log_string(f"args.{n} {getattr(args, n)}")
 
+for n in vars(sup_args):
+    log_string(f"sup_args.{n} {getattr(sup_args, n)}")
+
 
 ######################################################################
 
@@ -505,6 +512,9 @@ def get_lr(n_epoch, it):
 
 ######################################################################
 
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
 def picoclvr_pruner_horizontal_green(p):
     return not ("green" in p and ("left" in p or "right" in p))
 
@@ -730,6 +740,8 @@ model = mygpt.MyGPT(
     causal=True,
     dropout=args.dropout,
     attention_layer=args.attention,
+    logger=log_string,
+    **sup_args,
 )
 
 model.to(device)
@@ -838,6 +850,8 @@ time_pred_result = datetime.datetime.now()
 
 it = 0
 
+n_batch = 0
+
 for n_epoch in range(nb_epochs_finished, nb_epochs):
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@ -879,6 +893,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         total_loss.backward()
         optimizer.step()
 
+        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+
+        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+        n_batch += 1
+
     with torch.autograd.no_grad():
         model.eval()
 
diff --git a/mygpt.py b/mygpt.py
index a62cf49..7c9991f 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -190,6 +190,8 @@ class DumbRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -319,6 +321,8 @@ class KVRec(nn.Module):
         nb_lines,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -471,6 +475,8 @@ class Caterpillar(nn.Module):
         caterpillar_height,
         attention_dropout=0.0,
         len_max=1e5,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -487,12 +493,14 @@ class Caterpillar(nn.Module):
 
         self.proba_gate_dropout = 0.0
 
+        default_b_G = kwargs.get("default_b_G")
+        if default_b_G is None:
+            default_b_G = -math.log(caterpillar_height - 1)
+
+        logger(f"default_b_G {default_b_G}")
+
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
-        self.b_G = nn.Parameter(
-            torch.full(
-                (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
-            )
-        )
+        self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
 
         self.w_K = randw(nb_heads, dim_qk, dim_model)
         self.w_V = randw(nb_heads, dim_v, dim_model)
@@ -565,15 +573,20 @@ class Caterpillar(nn.Module):
             torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
         ).sigmoid()
 
+        # Clip the gating to avoid values greater than 1 when several
+        # heads hit the same row
+
+        G = G / G.sum(1, keepdim=True).clamp(min=1)
+
         ######################################################################
         # Roll the gating indexes
 
-        warnings.warn("rotating barrel", RuntimeWarning)
+        # warnings.warn("rotating barrel", RuntimeWarning)
 
-        r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
-        t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
-        r_barrel = (r_barrel + (t_barrel + t0) // L) % R
-        G = G.gather(dim=2, index=r_barrel.expand_as(G))
+        # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+        # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+        # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+        # G = G.gather(dim=2, index=r_barrel.expand_as(G))
 
         ######################################################################
         # The "flashbacks"
@@ -611,11 +624,6 @@ class Caterpillar(nn.Module):
 
         # We prepare the arguments for the parallel scan
 
-        # Clip the gating to avoid values greater than 1 when several
-        # heads hit the same row
-
-        G = G / G.sum(1, keepdim=True).clamp(min=1)
-
         A = 1 - G.sum(1)
 
         # warnings.warn("harmonic recurrence", RuntimeWarning)
@@ -709,6 +717,8 @@ class QKVAttention(nn.Module):
         nb_heads=1,
         causal=False,
         attention_dropout=0.0,
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -800,6 +810,8 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
         attention_layer="kvrec",
+        logger=print,
+        **kwargs,
     ):
         super().__init__()
 
@@ -836,6 +848,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     causal=causal,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             elif attention_layer == "dumbrec":
                 return DumbRec(
@@ -845,6 +859,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             elif attention_layer == "kvrec":
                 return KVRec(
@@ -854,6 +870,8 @@ class MyGPT(nn.Module):
                     nb_heads=nb_heads,
                     nb_lines=nb_lines,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
            elif attention_layer == "caterpillar":
                 return Caterpillar(
@@ -864,6 +882,8 @@ class MyGPT(nn.Module):
                     caterpillar_length=self.caterpillar_length,
                     caterpillar_height=self.caterpillar_height,
                     attention_dropout=dropout,
+                    logger=logger,
+                    **kwargs,
                 )
             else:
                 raise ValueError(f"Unknown attention type {attention_layer}.")
-- 
2.39.5
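
A note on the argument plumbing this patch adds in main.py: parse_known_args() keeps every option the parser does not recognize, the left-over "--name=value" strings become the sup_args dict, and that dict is forwarded to mygpt.MyGPT as **sup_args, from where it reaches the attention layers through **kwargs (Caterpillar reads kwargs.get("default_b_G")). The snippet below is a minimal self-contained sketch of that pattern, not code from the repository; the option value is invented for illustration, and the forwarded values arrive as strings.

import argparse

# Sketch of the pass-through mechanism: recognized options land in `args`,
# anything else is collected and later forwarded as **sup_args.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="mygpt")

args, extra = parser.parse_known_args(["--model=caterpillar", "--default_b_G=-1.5"])
sup_args = dict(x.removeprefix("--").split("=") for x in extra)

print(args.model)  # caterpillar
print(sup_args)    # {'default_b_G': '-1.5'}

# A constructor taking **kwargs can then pick the value up, the way
# Caterpillar does with kwargs.get("default_b_G"). The value is still a
# string at this point and has to be converted where it is used.
def make_layer(**kwargs):
    b = kwargs.get("default_b_G")
    return None if b is None else float(b)

print(make_layer(**sup_args))  # -1.5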
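A small numerical check of the two Caterpillar changes, assuming G has shape (N, H, R, T) with H the number of heads and R = caterpillar_height, as produced by the einsum in the patch. The default bias -log(R - 1) puts each gate at sigmoid(-log(R - 1)) = 1/R, and the clipping now applied right after the sigmoid rescales the gates of heads writing to the same row only when their sum exceeds 1, so that A = 1 - G.sum(1) used by the parallel scan stays non-negative. This is a toy verification, not part of the patch.

import math

import torch

torch.manual_seed(0)

N, H, R, T = 2, 3, 4, 5  # batch, heads, caterpillar_height, time (toy sizes)

# Default gate bias: sigmoid(-log(R - 1)) == 1 / R.
default_b_G = -math.log(R - 1)
assert abs(torch.sigmoid(torch.tensor(default_b_G)).item() - 1 / R) < 1e-6

# Random gates in (0, 1), shaped like the Caterpillar's G.
G = torch.rand(N, H, R, T)

# The clipping from the patch: where the per-row sum over heads exceeds 1,
# rescale so it becomes exactly 1; elsewhere leave G untouched.
G = G / G.sum(1, keepdim=True).clamp(min=1)

# Consequence: the retention coefficient A = 1 - G.sum(1) is never negative.
A = 1 - G.sum(1)
assert (G.sum(1) <= 1 + 1e-6).all()
assert (A >= -1e-6).all()
print(A.shape)  # torch.Size([2, 4, 5])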