From a3abd0f58cfb2f2448c82db836093d20dc2954f2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jan 2024 14:42:19 +0100 Subject: [PATCH] Update. --- main.py | 2 +- mygpt.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index fabebdd..18c0730 100755 --- a/main.py +++ b/main.py @@ -478,7 +478,7 @@ def get_lr(n_epoch, it): if it < args.nb_warmup_iter: return args.legacy_large_lr * it / args.nb_warmup_iter - elif it < args.legacy_nb_epoch_large_lr: + elif n_epoch < args.legacy_nb_epoch_large_lr: return args.legacy_large_lr else: return args.legacy_small_lr diff --git a/mygpt.py b/mygpt.py index d1acf22..33c6fee 100755 --- a/mygpt.py +++ b/mygpt.py @@ -530,11 +530,11 @@ class Caterpillar(nn.Module): ###################################################################### # Compute the recurrent state - # This is the Gating sequence that modulates if they key and - # values should be stored in one of the CH pairs of the - # current stack. The CH gating values are independent, which - # means that the same thing could be stored up to CH times or - # not at all + # This is the Gating sequence that modulates the storing of + # the new key and value in the CH pairs of the current + # stack. The CH gating values are independent, which means + # that the current K/V could be stored in all the pairs of the + # recurrent state, or not at all. G = ( torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] @@ -552,10 +552,11 @@ class Caterpillar(nn.Module): init_rec_V = self.rec_V[:, :, t0 - CL : t0] init_rec_K = self.rec_K[:, :, t0 - CL : t0] - # Here there is a trick: The parallel scan operates with a - # period of L, so we split the sequence indexing in two axes, - # the second of size CL, and run the parallel scan using the - # other alone as the sequence index. + # Here there is a trick: Since the stack at time t is computed + # by updating that at time t-L, the parallel scan operates + # with a period of L. To do so we split the time indexing in + # two axes, the second of size CL, and run the parallel scan + # using the other alone as the sequence index. A = A.unflatten(2, (-1, CL)) gated_V = gated_V.unflatten(2, (-1, CL)) -- 2.20.1