yield input
-def add_memex_v2(batches, memex_proba):
+# The marker token is not used in this version
+def add_memex_v2(batches, memex_proba, marker_token):
for input in batches:
if torch.rand(1).item() < memex_proba:
t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
yield input
+def add_memex_v3(batches, memex_proba, marker_token):
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ memex_len = input.size(1) // 4
+
+ t = torch.arange(input.size(1) + memex_len, device=input.device)[
+ None, :
+ ].expand(input.size(0), -1)
+
+ # Call me the tensor-spaghetti master
+
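+            # trigger: one-hot per row, at a position drawn uniformly in
+            # [0, input.size(1)); the tail forced to 1.0 can never hold
+            # the hot position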
+ trigger = torch.rand(t.size(), device=t.device)
+ trigger[:, -memex_len:] = 1.0
+ trigger = (trigger.sort(dim=1).indices == 0).long()
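+            # memex_mask: ones over the block of memex_len positions
+            # starting at the trigger position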
+            memex_mask = trigger.clone()
+            # read from the unmodified trigger to avoid an in-place op on
+            # overlapping views of memex_mask
+            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+ memex_mask = memex_mask.cumsum(dim=1)
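+            # u: source index for the regular positions; it advances by one
+            # outside the memex block and stalls inside it, so the original
+            # sequence is preserved around the inserted segment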
+ u = 1 - memex_mask
+ u[:, 0] = 0
+ u = u.cumsum(dim=1)
+ # assert u.min() == 0
+ # assert u.max() == input.size(1) - 1
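+            # v: source indices for the memex block itself, a contiguous
+            # chunk of the input starting at a random offset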
+            v = (
+                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+                # the offset is capped so the chunk stays within the input
+                + torch.randint(
+                    input.size(1) - memex_len + 1, (input.size(0), 1), device=t.device
+                )
+            ) * memex_mask
+ u = u * (1 - memex_mask) + v * memex_mask
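+            # per-row gather: new_input[i, j] = input[i, u[i, j]]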
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+ new_input = input[n, u]
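+            # limits: ones at the first and last positions of the memex
+            # block, both overwritten below with the marker token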
+            limits = trigger.clone()
+            # again read from trigger to avoid overlapping in-place views
+            limits[:, memex_len - 1 :] += trigger[:, : -(memex_len - 1)]
+            new_input = new_input * (1 - limits) + marker_token * limits
+
+ yield new_input, memex_mask
+
+        else:
+            # no memex inserted: an all-zero mask keeps the unpacking uniform
+            yield input, torch.zeros_like(input)
+
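+# A quick sanity check of the splicing (hypothetical, not part of the run):
+#
+#   batch = torch.randint(10, (2, 12))
+#   for x, m in add_memex_v3([batch], memex_proba=1.0, marker_token=10):
+#       print(x)  # rows of 12 + 3 tokens, marker 10 at both ends of the block
+#       print(m)  # three consecutive ones where the copied chunk sits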
+
######################################################################
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
vocabulary_size = task.vocabulary_size()
if args.memex_proba > 0:
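+    # reserve one extra token id to serve as the memex marker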
+ memex_marker = vocabulary_size
vocabulary_size += 1
log_string(f"vocabulary_size {vocabulary_size}")
return torch.cat([g1g1, g1g2, g2g2])
-movave_dot_products = 0
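+# Keep, as an attribute `name` of each parameter, an exponential moving
+# average of the gradient of `value` w.r.t. that parameter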
+def update_ave_grad(value, params, name, eps=1e-3):
+ for p in params:
+ g = torch.autograd.grad(value, p, retain_graph=True)[0]
+        ag = getattr(p, name, 0)
+ setattr(p, name, (1 - eps) * ag + eps * g)
+
+
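+# Squared L2 norm of the averaged gradient stored under attribute `name`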
+def norm(params, name):
+ s = 0
+ for p in params:
+ s += getattr(p, name).pow(2).sum()
+ return s
for n_epoch in range(nb_epochs_finished, nb_epochs):
log_string(f"memex_proba {memex_proba}")
- train_batches = add_memex_v2(
+ warnings.warn("memex v3", RuntimeWarning)
+ train_batches = add_memex_v3(
batches=task.batches(split="train"),
memex_proba=memex_proba,
+ marker_token=memex_marker,
)
def add_none(it):
loss_regular = (loss * (1 - memex_mask)).mean()
loss_memex = (loss * memex_mask).mean()
- if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01:
- dot_products = the_dot_products(
- loss_regular, loss_memex, model.parameters()
- )
- eps = 1e-3
- movave_dot_products = (
- 1 - eps
- ) * movave_dot_products + eps * dot_products
+            # refresh the gradient averages every batch early on, then ~1% of the time
+            if n_batch < 100 or torch.rand(1) < 0.01:
+ update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+ update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+ norm_regular = norm(model.parameters(), "grad_regular")
+ norm_memex = norm(model.parameters(), "grad_memex")
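+                # l_memex is 0 when the regular gradient dominates and grows
+                # toward 1 as the memex gradient (squared norm) gets larger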
+ l_memex = (
+ max(norm_regular, norm_memex) - norm_regular
+ ) / norm_memex
- grgr, grgm, gmgm = movave_dot_products
- l = (max(grgr, gmgm) - grgr) / gmgm
- loss = loss_regular + l * loss_memex
+ loss = loss_regular + l_memex * loss_memex
inner_loss = model.get_inner_loss()
optimizer.step()
grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
- grgr, grgm, gmgm = movave_dot_products
- l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm)
- lambda_file.write(f"{n_epoch} {n_batch} {l} {grgr} {gmgm}\n")
+ lambda_file.write(
+ f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+ )
optimizer.zero_grad()
nb_acc_samples = 0
n_batch += 1