def add_memex_v1(batches, memex_proba, marker_token):
    # Wrap a stream of (B, L) token batches (long tensors — TODO confirm).
    # With probability `memex_proba`, a batch is preceded by an augmented
    # "memex" sample of width 1 + 2*L built per row as: the row's first u0
    # tokens, a stretch of `marker_token`s, a full copy of the row starting
    # at u1, then `marker_token`s to the end.  It is yielded together with a
    # mask that is 1 on positions >= L.  The original batch is ALWAYS
    # yielded afterwards (unlike add_memex_v2/_v3, where augmentation is an
    # `else:` replacement) — NOTE(review): confirm the augmented sample is
    # meant to be yielded *in addition to* the batch, not instead of it.
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            # Per-row position index over the widened sample, shape (B, 1+2L).
            # Cloned so the in-place remap below doesn't touch shared storage.
            t = (
                torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
                .expand(input.size(0), -1)
                .clone()
            )

            # u0: per-row cut point in [0, L); the prefix input[:, :u0] is kept.
            u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
            # Relies on the module-level `args`; bounds the marker-filled gap
            # between the kept prefix and the copied row.
            caterpillar_length = args.nb_lines // args.caterpillar_height
            # u1 in [u0 + 1, u0 + caterpillar_length]: column where the full
            # copy of the row begins.
            u1 = (
                u0
                + torch.randint(
                    caterpillar_length, (input.size(0), 1), device=input.device
                )
                + 1
            )

            # m0 selects the kept prefix (t < u0); m1 selects the copied row
            # (u1 <= t < u1 + L).  Products of 0/1 long masks act as AND.
            m0 = (t < u0).long()
            m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()

            # Remap: keep t on the prefix, shift to t - u1 on the copy, and
            # set every remaining position to -1 (placeholder for the marker).
            t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
            # m flags the marker positions (where the remapped index is -1).
            m = (t < 0).long()
            # Row index aligned with t, for advanced indexing into `input`.
            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
                -1, t.size(1)
            )

            # Gather tokens (clamp avoids indexing with -1), then overwrite
            # the flagged positions with the marker token.
            new_input = input[n, t.clamp(min=0)]
            new_input = (1 - m) * new_input + m * (marker_token)

            # Mask is 1 exactly on the appended columns (index >= L).
            memex_mask = new_input.new_zeros(new_input.size())
            memex_mask[:, input.size(1) :] = 1.0

            yield new_input, memex_mask

        yield input
+
+
# The marker token is not used for this one
def add_memex_v2(batches, memex_proba, marker_token):
    # Wrap a stream of (B, L) token batches.  With probability `memex_proba`
    # a batch is replaced by a (B, L + L//4) sample: the original rows with a
    # random contiguous window ("flash") of L//4 of their own tokens appended,
    # plus a mask that is 1 exactly on the appended columns.
    for batch in batches:
        # Guard clause: most of the time the batch passes through untouched.
        if torch.rand(1).item() >= memex_proba:
            yield batch
            continue

        batch_size, seq_len = batch.size(0), batch.size(1)
        flash_len = seq_len // 4

        # Per-row window columns: a contiguous run of flash_len indices
        # starting at a random offset in [0, seq_len - flash_len).
        cols = torch.arange(flash_len, device=batch.device)[None, :].expand(
            batch_size, -1
        )
        cols = cols + torch.randint(
            seq_len - flash_len, (batch_size, 1), device=cols.device
        )
        # Matching row indices for advanced indexing.
        rows = torch.arange(batch_size, device=batch.device)[:, None].expand(
            -1, flash_len
        )

        # Gather the flash and append it to the original rows.
        extended = torch.cat([batch, batch[rows, cols]], dim=1)

        # 1 on the appended positions, 0 on the original sequence.
        flash_mask = extended.new_zeros(extended.size())
        flash_mask[:, seq_len:] = 1.0

        yield extended, flash_mask
+
+
def add_memex_v3(batches, memex_proba, marker_token):
    # Wrap a stream of (B, L) token batches.  With probability `memex_proba`
    # a batch is replaced by a (B, L + L//4) sample: the original sequence,
    # kept in order, is interrupted at a random interior column by a stretch
    # of L//4 positions that holds a copy of L//4 consecutive tokens taken
    # from a random location of the same row, with `marker_token` overwriting
    # both ends of the stretch.  Also yields a mask that is 1 exactly on the
    # stretch.  Relies on the module-level `vocabulary_size`.  The `assert`s
    # are internal sanity checks (stripped under `python -O`).
    # NOTE(review): assumes L//4 >= 2 — with memex_len == 1 the `limits`
    # slice `[:, :-(memex_len - 1)]` is empty and the += raises; with
    # memex_len == 0 the trigger setup degenerates.  Confirm L >= 8 upstream.
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            memex_len = input.size(1) // 4

            # Column index over the widened sample, shape (B, L + memex_len),
            # and matching row index for advanced indexing.
            t = torch.arange(input.size(1) + memex_len, device=input.device)[
                None, :
            ].expand(input.size(0), -1)
            n = torch.arange(input.size(0), device=input.device)[:, None].expand(
                -1, t.size(1)
            )

            # Call me the tensor-spaghetti master

            # One-hot per-row trigger column, drawn uniformly via argmin of
            # random noise.  Column 0 and the last memex_len columns are set
            # to 2.0 (> any rand value) so the trigger never lands there and
            # the stretch stays interior.
            trigger = torch.rand(t.size(), device=t.device)
            trigger[:, -memex_len:] = 2.0
            trigger[:, 0] = 2.0
            trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
            # +1 at the trigger, -1 memex_len columns later; cumsum turns that
            # into a 0/1 mask that is 1 on [trigger, trigger + memex_len).
            memex_mask = trigger.clone()
            memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
            memex_mask = memex_mask.cumsum(dim=1)

            # u: for each NON-stretch column, the index of the original token
            # it should carry — the cumsum enumerates 0..L-1 in order across
            # the L non-stretch columns (u is flat, and later discarded,
            # inside the stretch).
            u = 1 - memex_mask
            u[:, 0] = 0
            u = u.cumsum(dim=1)
            assert u.min() == 0
            assert u.max() == input.size(1) - 1

            # v: inside the stretch, the double cumsum counts 0..memex_len-1
            # from the trigger on; adding a per-row random offset in
            # [0, L - memex_len) makes the stretch copy input[r : r+memex_len].
            v = (
                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
                + torch.randint(
                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
                )
            ) * memex_mask
            assert v.min() >= 0
            assert v.max() < input.size(1)
            # Merge: original indices outside the stretch, copied indices inside.
            u = u * (1 - memex_mask) + v * memex_mask

            new_input = input[n, u]
            assert input.max() < vocabulary_size
            assert new_input.max() < vocabulary_size
            # limits: 1 at the first and last column of the stretch (the
            # trigger plus its image shifted memex_len - 1 columns right).
            limits = trigger.clone()
            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
            assert limits.min() == 0
            assert limits.max() == 1
            # Overwrite both stretch endpoints with the marker token.
            new_input = new_input * (1 - limits) + marker_token * limits
            assert marker_token < vocabulary_size
            assert new_input.max() < vocabulary_size

            yield new_input, memex_mask

        else:
            yield input
+
+
+######################################################################
+