Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 20 Feb 2024 07:52:38 +0000 (08:52 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 20 Feb 2024 07:52:38 +0000 (08:52 +0100)
main.py

diff --git a/main.py b/main.py
index 2a90fd1..00b8301 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -570,7 +570,8 @@ def add_memex_v1(batches, memex_proba, marker_token):
         yield input
 
 
-def add_memex_v2(batches, memex_proba):
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
     for input in batches:
         if torch.rand(1).item() < memex_proba:
             t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
@@ -595,6 +596,47 @@ def add_memex_v2(batches, memex_proba):
             yield input
 
 
+def add_memex_v3(batches, memex_proba, marker_token):
+    for input in batches:
+        if torch.rand(1).item() < memex_proba:
+            memex_len = input.size(1) // 4
+
+            t = torch.arange(input.size(1) + memex_len, device=input.device)[
+                None, :
+            ].expand(input.size(0), -1)
+
+            # Call me the tensor-spaghetti master
+
+            trigger = torch.rand(t.size(), device=t.device)
+            trigger[:, -memex_len:] = 1.0
+            trigger = (trigger.sort(dim=1).indices == 0).long()
+            memex_mask = trigger.clone()
+            memex_mask[:, memex_len:] -= memex_mask[:, :-memex_len]
+            memex_mask = memex_mask.cumsum(dim=1)
+            u = 1 - memex_mask
+            u[:, 0] = 0
+            u = u.cumsum(dim=1)
+            # assert u.min() == 0
+            # assert u.max() == input.size(1) - 1
+            v = (
+                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+                + torch.randint(input.size(1), (input.size(0), 1), device=t.device)
+            ) * memex_mask
+            u = u * (1 - memex_mask) + v * memex_mask
+            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+                -1, t.size(1)
+            )
+            new_input = input[n, u]
+            limits = trigger.clone()
+            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+            new_input = new_input * (1 - limits) + memex_marker * limits
+
+            yield new_input, memex_mask
+
+        else:
+            yield input
+
+
 ######################################################################
 
 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
@@ -814,6 +856,7 @@ log_string(f"device {device}")
 vocabulary_size = task.vocabulary_size()
 
 if args.memex_proba > 0:
+    memex_marker = vocabulary_size
     vocabulary_size += 1
 
 log_string(f"vocabulary_size {vocabulary_size}")
@@ -975,7 +1018,18 @@ def the_dot_products(value1, value2, params):
     return torch.cat([g1g1, g1g2, g2g2])
 
 
-movave_dot_products = 0
+def update_ave_grad(value, params, name, eps=1e-3):
+    for p in params:
+        g = torch.autograd.grad(value, p, retain_graph=True)[0]
+        ag = getattr(p, name) if hasattr(p, name) else 0
+        setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
+    s = 0
+    for p in params:
+        s += getattr(p, name).pow(2).sum()
+    return s
 
 
 for n_epoch in range(nb_epochs_finished, nb_epochs):
@@ -1000,9 +1054,11 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     log_string(f"memex_proba {memex_proba}")
 
-    train_batches = add_memex_v2(
+    warnings.warn("memex v3", RuntimeWarning)
+    train_batches = add_memex_v3(
         batches=task.batches(split="train"),
         memex_proba=memex_proba,
+        marker_token=memex_marker,
     )
 
     def add_none(it):
@@ -1032,18 +1088,16 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
                 loss_regular = (loss * (1 - memex_mask)).mean()
                 loss_memex = (loss * memex_mask).mean()
 
-                if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01:
-                    dot_products = the_dot_products(
-                        loss_regular, loss_memex, model.parameters()
-                    )
-                    eps = 1e-3
-                    movave_dot_products = (
-                        1 - eps
-                    ) * movave_dot_products + eps * dot_products
+                if it < 100 or torch.rand(1) < 0.01:
+                    update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+                    update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+                    norm_regular = norm(model.parameters(), "grad_regular")
+                    norm_memex = norm(model.parameters(), "grad_memex")
+                    l_memex = (
+                        max(norm_regular, norm_memex) - norm_regular
+                    ) / norm_memex
 
-                grgr, grgm, gmgm = movave_dot_products
-                l = (max(grgr, gmgm) - grgr) / gmgm
-                loss = loss_regular + l * loss_memex
+                loss = loss_regular + l_memex * loss_memex
 
             inner_loss = model.get_inner_loss()
 
@@ -1072,9 +1126,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             optimizer.step()
             grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
             loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
-            grgr, grgm, gmgm = movave_dot_products
-            l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm)
-            lambda_file.write(f"{n_epoch} {n_batch} {l} {grgr} {gmgm}\n")
+            lambda_file.write(
+                f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+            )
             optimizer.zero_grad()
             nb_acc_samples = 0
             n_batch += 1