exit(1)
loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
######################################################################
-def add_memex_v2(batches, memex_proba, marker_token):
+def add_memex_v1(batches, memex_proba, marker_token):
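+ # With probability memex_proba, extend the sequence with a re-read of
+ # a segment of it (marker_token filling the masked positions) and
+ # yield (new_input, memex_mask), the mask being 1.0 on the appended
+ # part; otherwise yield the batch unchanged.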
for input in batches:
if torch.rand(1).item() < memex_proba:
t = (
new_input = input[n, t.clamp(min=0)]
new_input = (1 - m) * new_input + m * (marker_token)
- yield new_input
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
yield input
-def add_memex_v3(batches, memex_proba, marker_token):
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
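+ # With probability memex_proba, append to each sequence a copy of a
+ # random contiguous fragment of itself, of one quarter of its length,
+ # and yield (new_input, memex_mask), the mask being 1.0 on the
+ # appended fragment; otherwise yield the batch unchanged.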
for input in batches:
if torch.rand(1).item() < memex_proba:
- t = (
- torch.arange(2 * input.size(1), device=input.device)[None, :]
- .expand(input.size(0), -1)
- .clone()
+ t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+ input.size(0), -1
+ )
+ t = t + torch.randint(
+ input.size(1) - t.size(1), (t.size(0), 1), device=t.device
+ )
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
)
- u = torch.rand(t.size(), device=t.device)
- u[:, : input.size(1)] = 1.0
- memex_v3_proba_fragment = 1 / 20
- u = (u < memex_v3_proba_fragment).long()
- v = u * torch.randint(input.size(1), u.size())
- u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
- :, : input.size(1) - 1
- ] * input.size(1)
- u = u.cumsum().clamp(min=0)
+ flash = input[n, t]
+ new_input = torch.cat([input, flash], dim=1)
- u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
- caterpillar_length = args.nb_lines // args.caterpillar_height
- u1 = (
- u0
- + torch.randint(
- caterpillar_length, (input.size(0), 1), device=input.device
- )
- + 1
- )
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
- m0 = (t < u0).long()
- m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+ yield new_input, memex_mask
- t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
- m = (t < 0).long()
- n = torch.arange(input.size(0), device=input.device)[:, None].expand(
- -1, t.size(1)
+ else:
+ yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
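+ # Replace, with probability memex_proba per sample, the sequence with
+ # a version in which a marker-delimited copy of a random fragment of
+ # length input.size(1) // 8 is inserted at a random position; the
+ # untouched samples are padded with marker tokens so that all the
+ # sequences share the extended length.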
+ for input in batches:
+ memex_len = input.size(1) // 8
+
+ t = torch.arange(input.size(1) + memex_len, device=input.device)[
+ None, :
+ ].expand(input.size(0), -1)
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ t = (t - 1).clamp(min=0)
+
+ # Call me the tensor-spaghetti master
+
+ trigger = torch.rand(t.size(), device=t.device)
+ trigger[:, -memex_len:] = 2.0
+ trigger[:, : memex_len + 1] = 2.0
+ trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
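+ # trigger is now a per-row one-hot at a position drawn uniformly,
+ # excluding the first memex_len + 1 and last memex_len positions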
+ memex_mask = trigger.clone()
+ memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+ memex_mask = memex_mask.cumsum(dim=1)
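+ # memex_mask is now 1 on the window of memex_len positions starting
+ # at the trigger, 0 elsewhere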
+
+ u = 1 - memex_mask
+ u[:, 0] = 0
+ u = u.cumsum(dim=1)
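+ # outside the window, u is now the index in the original sequence,
+ # i.e. positions after the window are shifted back by memex_len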
+
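+ # inside the window, v counts up from a random start, indexing a
+ # random contiguous fragment of the original sequence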
+ v = (
+ (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+ + torch.randint(
+ input.size(1) - memex_len, (input.size(0), 1), device=t.device
)
+ ) * memex_mask
+ u = u * (1 - memex_mask) + v * memex_mask
+
+ new_input = input[n, u]
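+ # overwrite the first and last positions of the window, and
+ # position 0, with the marker token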
+ limits = trigger.clone()
+ limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+ new_input = new_input * (1 - limits) + marker_token * limits
+ new_input[:, 0] = marker_token
+
+ orig = torch.cat(
+ [
+ input,
+ torch.full((input.size(0), memex_len), marker_token, device=t.device),
+ ],
+ dim=1,
+ )
- new_input = input[n, t.clamp(min=0)]
- new_input = (1 - m) * new_input + m * (marker_token)
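+ # per-sample choice: keep the memex version with probability
+ # memex_proba, otherwise the original padded with markers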
+ a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
- yield new_input
+ new_input = (1 - a) * orig + a * new_input
- yield input
+ yield new_input # memex_mask not yielded: the weighted memex loss in the training loop then stays off for v3
######################################################################
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+assert args.batch_size % args.physical_batch_size == 0
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
vocabulary_size = task.vocabulary_size()
if args.memex_proba > 0:
+ memex_marker = vocabulary_size
vocabulary_size += 1
log_string(f"vocabulary_size {vocabulary_size}")
n_batch = 0
+
+def the_dot_products(value1, value2, params):
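+ # Returns a tensor of size 3 with the dot products <g1, g1>,
+ # <g1, g2>, <g2, g2> of the gradients of value1 and value2 with
+ # respect to the given parameters.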
+ g1g1, g1g2, g2g2 = 0, 0, 0
+ for p in params:
+ g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+ g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+ g1g1 += g1.pow(2).sum()[None]
+ g2g2 += g2.pow(2).sum()[None]
+ g1g2 += (g1 * g2).sum()[None]
+ return torch.cat([g1g1, g1g2, g2g2])
+
+
+def update_ave_grad(value, params, name, eps=1e-3):
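+ # Updates, as attribute `name` of every parameter, an exponential
+ # moving average of the gradient of value w.r.t. it.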
+ for p in params:
+ g = torch.autograd.grad(value, p, retain_graph=True)[0]
+ ag = getattr(p, name, 0)
+ setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
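+ # Returns the squared L2 norm of the averaged gradients stored as
+ # attribute `name` of the parameters.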
+ s = 0
+ for p in params:
+ s += getattr(p, name).pow(2).sum()
+ return s
+
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
log_string(f"memex_proba {memex_proba}")
- train_batches = add_memex_v2(
- batches=task.batches(split="train"),
- memex_proba=memex_proba,
- marker_token=vocabulary_size - 1,
- )
+ if args.memex_proba > 0:
+ warnings.warn("memex v3", RuntimeWarning)
+ train_batches = add_memex_v3(
+ batches=task.batches(split="train"),
+ memex_proba=memex_proba,
+ marker_token=memex_marker,
+ )
+ else:
+ train_batches = task.batches(split="train")
def add_none(it):
for x in it:
for input in add_none(train_batches):
if input is not None:
+ if type(input) is tuple:
+ input, memex_mask = input
+ memex_mask = memex_mask.to(device)
+ else:
+ memex_mask = None
+
model.reset_inner_loss()
input = input.to(device)
output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
+
+ if memex_mask is None:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ else:
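+ # per-token losses, averaged separately over the regular and the
+ # memex positions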
+ loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ loss_regular = (loss * (1 - memex_mask)).mean()
+ loss_memex = (loss * memex_mask).mean()
+
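+ # refresh the gradient statistics on the first 100 batches, then
+ # on roughly 1% of them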
+ if n_batch < 100 or torch.rand(1) < 0.01:
+ update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+ update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+ norm_regular = norm(model.parameters(), "grad_regular")
+ norm_memex = norm(model.parameters(), "grad_memex")
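+ # weight on the memex loss: 0 when the averaged memex gradient
+ # is smaller than the regular one, 1 - norm_regular / norm_memex
+ # otherwise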
+ l_memex = (
+ max(norm_regular, norm_memex) - norm_regular
+ ) / norm_memex
+
+ loss = loss_regular + l_memex * loss_memex
+
inner_loss = model.get_inner_loss()
acc_train_loss += loss.item() * input.size(0)
optimizer.step()
grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ if memex_mask is not None:
+ lambda_file.write(
+ f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+ )
optimizer.zero_grad()
nb_acc_samples = 0
-
- n_batch += 1
+ n_batch += 1
with torch.autograd.no_grad():
model.eval()