) % k_star.size(0)
k_star = k_star[l_barrel, t_barrel]
+
+######################################################################
+
+2024 Feb 15 23:10:50 (from main.py)
+
+
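+# With probability memex_proba, emit before a batch an extra "memex" batch
+# of twice the length: the original up to a random position u0, marker
+# tokens up to a random position u1, then the original replayed from u1.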
+def add_memex_v4(batches, memex_proba, marker_token):
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
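+            # NOTE: u and v are computed here but never used below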
+ u = torch.rand(t.size(), device=t.device)
+ u[:, : input.size(1)] = 1.0
+ memex_v3_proba_fragment = 1 / 20
+ u = (u < memex_v3_proba_fragment).long()
+            v = u * torch.randint(input.size(1), u.size(), device=u.device)
+ u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
+ :, : input.size(1) - 1
+ ] * input.size(1)
+            u = u.cumsum(dim=1).clamp(min=0)
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
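+            # remap positions: identity before u0, the replayed input at
+            # t - u1 over [u1, u1 + input.size(1)), and -1 (marker) elsewhere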
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ yield new_input
+
+ yield input
+
+
+
+######################################################################
+
+2024 Feb 16 17:07:48 (from main.py)
+
+    # ||gn + lambda * gm|| = max(||gn||, ||gm||)
+    # ||gn||^2 + 2 * lambda * <gn,gm> + lambda^2 * ||gm||^2 = max(||gn||^2, ||gm||^2)
+    # A = ||gm||^2   B = 2 * <gn,gm>   C = ||gn||^2 - max(||gn||^2, ||gm||^2)
+
+######################################################################
+
+2024 Feb 16 17:07:51 (from main.py)
+
+    # A, B, C = gmgm, 2 * gngm, gngn - max(gngn, gmgm)
+    # Delta = B * B - 4 * A * C
+    # if Delta >= 0:
+    #     l = (-B - sqrt(Delta)) / (2 * A)
+    # rho variant: ||gn||^2 + l * rho * ||gm||^2 = max(||gn||^2, rho * ||gm||^2)
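+    #
+    # the training loop uses the simpler form with the cross term dropped,
+    # l = -C / A = (max(gngn, gmgm) - gngn) / gmgm
+    # (the exact B = 0 root would be the square root of that ratio)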
exit(1)
loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
######################################################################
-def add_memex_v2(batches, memex_proba, marker_token):
+def add_memex_v1(batches, memex_proba, marker_token):
for input in batches:
if torch.rand(1).item() < memex_proba:
t = (
new_input = input[n, t.clamp(min=0)]
new_input = (1 - m) * new_input + m * (marker_token)
- yield new_input
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
yield input
-def add_memex_v3(batches, memex_proba, marker_token):
+def add_memex_v2(batches, memex_proba):
for input in batches:
if torch.rand(1).item() < memex_proba:
- t = (
- torch.arange(2 * input.size(1), device=input.device)[None, :]
- .expand(input.size(0), -1)
- .clone()
+ t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+ input.size(0), -1
)
-
- u = torch.rand(t.size(), device=t.device)
- u[:, : input.size(1)] = 1.0
- memex_v3_proba_fragment = 1 / 20
- u = (u < memex_v3_proba_fragment).long()
- v = u * torch.randint(input.size(1), u.size())
- u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
- :, : input.size(1) - 1
- ] * input.size(1)
- u = u.cumsum().clamp(min=0)
-
- u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
- caterpillar_length = args.nb_lines // args.caterpillar_height
- u1 = (
- u0
- + torch.randint(
- caterpillar_length, (input.size(0), 1), device=input.device
- )
- + 1
+ t = t + torch.randint(
+ input.size(1) - t.size(1), (t.size(0), 1), device=t.device
)
-
- m0 = (t < u0).long()
- m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
-
- t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
- m = (t < 0).long()
n = torch.arange(input.size(0), device=input.device)[:, None].expand(
-1, t.size(1)
)
- new_input = input[n, t.clamp(min=0)]
- new_input = (1 - m) * new_input + m * (marker_token)
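+            # "flash": a random contiguous chunk, a quarter of the sequence
+            # long, appended after the original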
+ flash = input[n, t]
+ new_input = torch.cat([input, flash], dim=1)
- yield new_input
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
- yield input
+ yield new_input, memex_mask
+
+ else:
+ yield input
######################################################################
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+assert args.batch_size % args.physical_batch_size == 0
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
n_batch = 0
+
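+# Returns the 3-element tensor [<g1,g1>, <g1,g2>, <g2,g2>], where g1 and g2
+# are the gradients of value1 and value2 w.r.t. params.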
+def the_dot_products(value1, value2, params):
+ g1g1, g1g2, g2g2 = 0, 0, 0
+ for p in params:
+ g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+ g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+ g1g1 += g1.pow(2).sum()[None]
+ g2g2 += g2.pow(2).sum()[None]
+ g1g2 += (g1 * g2).sum()[None]
+ return torch.cat([g1g1, g1g2, g2g2])
+
+
+movave_dot_products = 0
+
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
train_batches = add_memex_v2(
batches=task.batches(split="train"),
memex_proba=memex_proba,
- marker_token=vocabulary_size - 1,
)
def add_none(it):
for input in add_none(train_batches):
if input is not None:
+ if type(input) is tuple:
+ input, memex_mask = input
+ memex_mask = memex_mask.to(device)
+ else:
+ memex_mask = None
+
model.reset_inner_loss()
input = input.to(device)
output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
+
+ if memex_mask is None:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ else:
+ loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ loss_regular = (loss * (1 - memex_mask)).mean()
+ loss_memex = (loss * memex_mask).mean()
+
+ if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01:
+ dot_products = the_dot_products(
+ loss_regular, loss_memex, model.parameters()
+ )
+ eps = 1e-3
+ movave_dot_products = (
+ 1 - eps
+ ) * movave_dot_products + eps * dot_products
+
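+                # pick l so that grgr + l * gmgm = max(grgr, gmgm), i.e. the
+                # A, B, C quadratic above with the cross term dropped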
+ grgr, grgm, gmgm = movave_dot_products
+ l = (max(grgr, gmgm) - grgr) / gmgm
+ loss = loss_regular + l * loss_memex
+
inner_loss = model.get_inner_loss()
acc_train_loss += loss.item() * input.size(0)
optimizer.step()
grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            if torch.is_tensor(movave_dot_products):
+                grgr, grgm, gmgm = movave_dot_products
+                l = (max(grgr, rho * gmgm) - grgr) / (rho * gmgm)
+                lambda_file.write(
+                    f"{n_epoch} {n_batch} {l.item()} {grgr.item()} {gmgm.item()}\n"
+                )
optimizer.zero_grad()
nb_acc_samples = 0
-
- n_batch += 1
+ n_batch += 1
with torch.autograd.no_grad():
model.eval()