Update. master
authorFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 07:50:05 +0000 (09:50 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 07:50:05 +0000 (09:50 +0200)
main.py
pscan.py

diff --git a/main.py b/main.py
index a587e96..88f56b3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -598,57 +598,58 @@ def add_memex_v2(batches, memex_proba, marker_token):
 
 def add_memex_v3(batches, memex_proba, marker_token):
     for input in batches:
-        if torch.rand(1).item() < memex_proba:
-            memex_len = input.size(1) // 4
+        memex_len = input.size(1) // 8
 
-            t = torch.arange(input.size(1) + memex_len, device=input.device)[
-                None, :
-            ].expand(input.size(0), -1)
-            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
-                -1, t.size(1)
-            )
+        t = torch.arange(input.size(1) + memex_len, device=input.device)[
+            None, :
+        ].expand(input.size(0), -1)
+        n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+            -1, t.size(1)
+        )
 
-            # Call me the tensor-spaghetti master
+        t = (t - 1).clamp(min=0)
 
-            trigger = torch.rand(t.size(), device=t.device)
-            trigger[:, -memex_len:] = 2.0
-            trigger[:, 0] = 2.0
-            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
-            memex_mask = trigger.clone()
-            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
-            memex_mask = memex_mask.cumsum(dim=1)
+        # Call me the tensor-spaghetti master
 
-            u = 1 - memex_mask
-            u[:, 0] = 0
-            u = u.cumsum(dim=1)
-            assert u.min() == 0
-            assert u.max() == input.size(1) - 1
+        trigger = torch.rand(t.size(), device=t.device)
+        trigger[:, -memex_len:] = 2.0
+        trigger[:, : memex_len + 1] = 2.0
+        trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+        memex_mask = trigger.clone()
+        memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+        memex_mask = memex_mask.cumsum(dim=1)
 
-            v = (
-                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
-                + torch.randint(
-                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
-                )
-            ) * memex_mask
-            assert v.min() >= 0
-            assert v.max() < input.size(1)
-            u = u * (1 - memex_mask) + v * memex_mask
-
-            new_input = input[n, u]
-            assert input.max() < vocabulary_size
-            assert new_input.max() < vocabulary_size
-            limits = trigger.clone()
-            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
-            assert limits.min() == 0
-            assert limits.max() == 1
-            new_input = new_input * (1 - limits) + marker_token * limits
-            assert marker_token < vocabulary_size
-            assert new_input.max() < vocabulary_size
+        u = 1 - memex_mask
+        u[:, 0] = 0
+        u = u.cumsum(dim=1)
 
-            yield new_input, memex_mask
+        v = (
+            (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+            + torch.randint(
+                input.size(1) - memex_len, (input.size(0), 1), device=t.device
+            )
+        ) * memex_mask
+        u = u * (1 - memex_mask) + v * memex_mask
+
+        new_input = input[n, u]
+        limits = trigger.clone()
+        limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+        new_input = new_input * (1 - limits) + marker_token * limits
+        new_input[:, 0] = marker_token
+
+        orig = torch.cat(
+            [
+                input,
+                torch.full((input.size(0), memex_len), memex_marker, device=t.device),
+            ],
+            dim=1,
+        )
 
-        else:
-            yield input
+        a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
+
+        new_input = (1 - a) * orig + a * new_input
+
+        yield new_input  # memex_mask
 
 
 ######################################################################
@@ -1068,12 +1069,15 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     log_string(f"memex_proba {memex_proba}")
 
-    warnings.warn("memex v3", RuntimeWarning)
-    train_batches = add_memex_v3(
-        batches=task.batches(split="train"),
-        memex_proba=memex_proba,
-        marker_token=memex_marker,
-    )
+    if args.memex_proba > 0:
+        warnings.warn("memex v3", RuntimeWarning)
+        train_batches = add_memex_v3(
+            batches=task.batches(split="train"),
+            memex_proba=memex_proba,
+            marker_token=memex_marker,
+        )
+    else:
+        train_batches = task.batches(split="train")
 
     def add_none(it):
         for x in it:
@@ -1140,9 +1144,10 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
             optimizer.step()
             grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
             loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
-            lambda_file.write(
-                f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
-            )
+            if memex_mask is not None:
+                lambda_file.write(
+                    f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+                )
             optimizer.zero_grad()
             nb_acc_samples = 0
             n_batch += 1
index 0bb0d14..b533164 100755 (executable)
--- a/pscan.py
+++ b/pscan.py
@@ -124,17 +124,17 @@ if __name__ == "__main__":
 
     ######################################################################
 
-    N, T, D = 16, 4096, 32
+    N, T, D = 16, 4096, 32
 
-    for r in range(timing.size(0)):
-        A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
-        X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
-        Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
+    for r in range(timing.size(0)):
+    # A = 0.9 + 0.1 * torch.rand(N, T, dtype=torch.float64).requires_grad_()
+    # X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+    # Y_init = torch.randn(N, D, dtype=torch.float64).requires_grad_()
 
-        start_time = time.perf_counter()
-        for _ in range(1000):
-            Y = pscan(A, X, Y_init)
-        duration = time.perf_counter() - start_time
+    # start_time = time.perf_counter()
+    # for _ in range(1000):
+    # Y = pscan(A, X, Y_init)
+    # duration = time.perf_counter() - start_time
 
     ######################################################################