from torch import nn
from torch.nn import functional as F
+# torch.autograd.set_detect_anomaly(True)  # debugging only: makes training much slower
+
+import warnings  # used by the memex-version warning in the training loop
+
import ffutils
import mygpt, tasks, problems
########################################
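+
+# str2bool is used for the boolean flags below, since argparse's type=bool
+# treats any non-empty string (including "False") as True. A minimal sketch,
+# in case it is not already defined elsewhere in this file:
+def str2bool(x):
+    x = x.lower()
+    if x in {"1", "true", "yes"}:
+        return True
+    elif x in {"0", "false", "no"}:
+        return False
+    else:
+        raise ValueError(f"cannot interpret {x} as a boolean")
+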
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--physical_batch_size", type=int, default=None)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=None)
parser.add_argument("--attention", type=str, default=None)
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
+
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--dim_keys", type=int, default=None)
parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
-parser.add_argument("--gate_dropout_sync", type=bool, default=False)
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
parser.add_argument("--rho_inner_loss", type=float, default=0.0)
default_task_args = {
"addition": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"byheart": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"expr": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"grid": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"qmlp": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 10,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"guessop": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
"learnop": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"maze": {
"model": "37M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
},
"picoclvr": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"rpl": {
"model": "352M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"snake": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"twotargets": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"memory": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 25000,
"nb_test_samples": 10000,
},
"mixing": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 5,
"nb_train_samples": 60000,
"nb_test_samples": 10000,
},
exit(1)
loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
######################################################################
+def add_memex_v1(batches, memex_proba, marker_token):
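+    # With probability memex_proba, first yield an augmented batch of width
+    # 1 + 2L: a random-length prefix of each sequence, a stretch of marker
+    # tokens, then a full copy of the sequence, together with a mask over
+    # the positions from L on; the original batch is always yielded after.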
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
+
+ yield input
+
+
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
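+    # With probability memex_proba, append to each sequence a random
+    # contiguous slice of itself of length L // 4 (a "flash"), and yield the
+    # result with a mask over the appended part; otherwise yield the batch
+    # unchanged.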
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+ input.size(0), -1
+ )
+ t = t + torch.randint(
+ input.size(1) - t.size(1), (t.size(0), 1), device=t.device
+ )
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ flash = input[n, t]
+ new_input = torch.cat([input, flash], dim=1)
+
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
+
+ else:
+ yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
+ for input in batches:
+ memex_len = input.size(1) // 8
+
+ t = torch.arange(input.size(1) + memex_len, device=input.device)[
+ None, :
+ ].expand(input.size(0), -1)
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ t = (t - 1).clamp(min=0)
+
+ # Call me the tensor-spaghetti master
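+        #
+        # trigger marks one random position per row, away from both ends;
+        # memex_mask is the memex_len-long window starting there; outside the
+        # window u indexes the original sequence, inside it a random earlier
+        # slice of it; limits brackets the window with marker tokens.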
+
+ trigger = torch.rand(t.size(), device=t.device)
+ trigger[:, -memex_len:] = 2.0
+ trigger[:, : memex_len + 1] = 2.0
+ trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+ memex_mask = trigger.clone()
+ memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+ memex_mask = memex_mask.cumsum(dim=1)
+
+ u = 1 - memex_mask
+ u[:, 0] = 0
+ u = u.cumsum(dim=1)
+
+ v = (
+ (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+ + torch.randint(
+ input.size(1) - memex_len, (input.size(0), 1), device=t.device
+ )
+ ) * memex_mask
+ u = u * (1 - memex_mask) + v * memex_mask
+
+ new_input = input[n, u]
+ limits = trigger.clone()
+ limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+ new_input = new_input * (1 - limits) + marker_token * limits
+ new_input[:, 0] = marker_token
+
+ orig = torch.cat(
+ [
+ input,
+                torch.full((input.size(0), memex_len), marker_token, device=t.device),
+ ],
+ dim=1,
+ )
+
+ a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
+
+ new_input = (1 - a) * orig + a * new_input
+
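+        # NB: unlike v1/v2, this version always yields a plain tensor, so the
+        # memex-mask branch of the training loop is not exercised with v3.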
+ yield new_input # memex_mask
+
+
+######################################################################
+
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
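+
+# A logical batch of args.batch_size samples is accumulated from several
+# physical batches of args.physical_batch_size between optimizer steps,
+# hence the divisibility requirement.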
+assert args.batch_size % args.physical_batch_size == 0
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
problem=problems.ProblemByHeart(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemLearnOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemGuessOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemTwoTargets(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemMemory(len_total=args.memory_len_total),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemAddition(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.picoclvr_height,
width=args.picoclvr_width,
nb_colors=args.picoclvr_nb_colors,
task = tasks.MNIST(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.Maze(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.maze_height,
width=args.maze_width,
nb_walls=args.maze_nb_walls,
task = tasks.Snake(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.snake_height,
width=args.snake_width,
nb_colors=args.snake_nb_colors,
task = tasks.Stack(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
nb_steps=args.stack_nb_steps,
nb_stacks=args.stack_nb_stacks,
sequence_length=args.expr_sequence_length,
operand_max=args.expr_operand_max,
result_max=args.expr_result_max,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.RPL(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
nb_starting_values=args.rpl_nb_starting_values,
max_input=args.rpl_max_input,
prog_len=args.rpl_prog_len,
task = tasks.Grid(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
size=args.grid_size,
nb_shapes=args.grid_nb_shapes,
nb_colors=args.grid_nb_colors,
task = tasks.QMLP(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=device_data,
vocabulary_size = task.vocabulary_size()
+if args.memex_proba > 0:
+ memex_marker = vocabulary_size
+ vocabulary_size += 1
+
log_string(f"vocabulary_size {vocabulary_size}")
##############################
n_batch = 0
+
+def the_dot_products(value1, value2, params):
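+    # Returns [<g1,g1>, <g1,g2>, <g2,g2>], where g1 and g2 are the gradients
+    # of value1 and value2, summed over params (apparently unused in this diff).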
+ g1g1, g1g2, g2g2 = 0, 0, 0
+ for p in params:
+ g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+ g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+ g1g1 += g1.pow(2).sum()[None]
+ g2g2 += g2.pow(2).sum()[None]
+ g1g2 += (g1 * g2).sum()[None]
+ return torch.cat([g1g1, g1g2, g2g2])
+
+
+def update_ave_grad(value, params, name, eps=1e-3):
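+    # Maintains on each p in params, as attribute `name`, an exponential
+    # moving average of the gradient of value w.r.t. p.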
+ for p in params:
+ g = torch.autograd.grad(value, p, retain_graph=True)[0]
+ ag = getattr(p, name) if hasattr(p, name) else 0
+ setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
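+    # Sum of the squared L2 norms of the attribute `name` over params; note
+    # that no square root is taken, despite the function's name.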
+ s = 0
+ for p in params:
+ s += getattr(p, name).pow(2).sum()
+ return s
+
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- for input in task.batches(split="train"):
- model.reset_inner_loss()
- input = input.to(device)
-
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- inner_loss = model.get_inner_loss()
-
- acc_train_loss += loss.item() * input.size(0)
- acc_train_inner_loss += inner_loss.item() * input.size(0)
+ memex_proba = (
+ args.memex_proba
+ if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+ else 0.0
+ )
- nb_train_samples += input.size(0)
- nb_samples_seen += input.size(0)
+ log_string(f"memex_proba {memex_proba}")
- total_loss = loss + (
- args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ if args.memex_proba > 0:
+ warnings.warn("memex v3", RuntimeWarning)
+ train_batches = add_memex_v3(
+ batches=task.batches(split="train"),
+ memex_proba=memex_proba,
+ marker_token=memex_marker,
)
+ else:
+ train_batches = task.batches(split="train")
- it += 1
- lr = get_lr(n_epoch, it)
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
+ def add_none(it):
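+        # Yields all the batches, then a None sentinel, so that the loop
+        # below can flush a final, possibly incomplete, logical batch.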
+ for x in it:
+ yield x
+ yield None
- # log_string(f"learning_rate {lr}")
+ nb_acc_samples = 0
- optimizer.zero_grad()
- total_loss.backward()
- optimizer.step()
+ for input in add_none(train_batches):
+ if input is not None:
+ if type(input) is tuple:
+ input, memex_mask = input
+ memex_mask = memex_mask.to(device)
+ else:
+ memex_mask = None
- grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+ model.reset_inner_loss()
+ input = input.to(device)
- loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ output = model(mygpt.BracketedSequence(input)).x
- n_batch += 1
+ if memex_mask is None:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ else:
+ loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ loss_regular = (loss * (1 - memex_mask)).mean()
+ loss_memex = (loss * memex_mask).mean()
+
+ if it < 100 or torch.rand(1) < 0.01:
+ update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+ update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+ norm_regular = norm(model.parameters(), "grad_regular")
+ norm_memex = norm(model.parameters(), "grad_memex")
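+                # Weight the memex loss so that its averaged-gradient norm
+                # cannot dominate: l_memex is 0 when norm_regular >=
+                # norm_memex, and 1 - norm_regular / norm_memex otherwise.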
+ l_memex = (
+ max(norm_regular, norm_memex) - norm_regular
+ ) / norm_memex
+
+ loss = loss_regular + l_memex * loss_memex
+
+ inner_loss = model.get_inner_loss()
+
+ acc_train_loss += loss.item() * input.size(0)
+ acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
+
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
+
+ it += 1
+ lr = get_lr(n_epoch, it)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+ # log_string(f"learning_rate {lr}")
+
+ total_loss.backward()
+ nb_acc_samples += input.size(0)
+
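+        # Take an optimizer step once a full logical batch has been
+        # accumulated, or at the final sentinel flush.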
+ if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+ assert nb_acc_samples <= args.batch_size
+ optimizer.step()
+            grad_norm = sum(
+                p.grad.pow(2).sum() for p in model.parameters() if p.grad is not None
+            ).sqrt()
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ if memex_mask is not None:
+ lambda_file.write(
+ f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+ )
+ optimizer.zero_grad()
+ nb_acc_samples = 0
+ n_batch += 1
with torch.autograd.no_grad():
model.eval()