if it < args.nb_warmup_iter:
return args.legacy_large_lr * it / args.nb_warmup_iter
- elif it < args.legacy_nb_epoch_large_lr:
+ elif n_epoch < args.legacy_nb_epoch_large_lr:
return args.legacy_large_lr
else:
return args.legacy_small_lr
######################################################################
# Compute the recurrent state
- # This is the Gating sequence that modulates if they key and
- # values should be stored in one of the CH pairs of the
- # current stack. The CH gating values are independent, which
- # means that the same thing could be stored up to CH times or
- # not at all
+ # This is the Gating sequence that modulates the storing of
+ # the new key and value in the CH pairs of the current
+ # stack. The CH gating values are independent, which means
+ # that the current K/V could be stored in all the pairs of the
+ # recurrent state, or not at all.
G = (
torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None]
init_rec_V = self.rec_V[:, :, t0 - CL : t0]
init_rec_K = self.rec_K[:, :, t0 - CL : t0]
- # Here there is a trick: The parallel scan operates with a
- # period of L, so we split the sequence indexing in two axes,
- # the second of size CL, and run the parallel scan using the
- # other alone as the sequence index.
+ # Here there is a trick: Since the stack at time t is computed
+ # by updating that at time t-L, the parallel scan operates
+ # with a period of L. To do so we split the time indexing in
+ # two axes, the second of size CL, and run the parallel scan
+ # using the other alone as the sequence index.
A = A.unflatten(2, (-1, CL))
gated_V = gated_V.unflatten(2, (-1, CL))