from torch import nn
from torch.nn import functional as F
+# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
import ffutils
import mygpt, tasks, problems
######################################################################
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
-######################################################################
+def str2bool(x):
+ x = x.lower()
+ if x in {"1", "true", "yes"}:
+ return True
+ elif x in {"0", "false", "no"}:
+ return False
+ else:
+ raise ValueError
+
parser = argparse.ArgumentParser(
description="An implementation of GPT with cache.",
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
########################################
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--physical_batch_size", type=int, default=None)
+
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=None)
# legacy
-parser.add_argument("--legacy_lr_schedule", action="store_true", default=False)
+parser.add_argument("--legacy_lr_schedule", type=str2bool, default=True)
parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
parser.add_argument("--attention", type=str, default=None)
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
+
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--dim_keys", type=int, default=None)
parser.add_argument("--caterpillar_height", type=int, default=None)
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
-parser.add_argument("--dim_rec_v", type=int, default=None)
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--continue_training", action="store_true", default=False)
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
parser.add_argument("--grid_size", type=int, default=6)
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
##############################
# picoclvr options
######################################################################
-args = parser.parse_args()
+# args = parser.parse_args()
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
######################################################################
+if not args.force_cpu and torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
default_task_args = {
"addition": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"byheart": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"expr": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"grid": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"qmlp": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 10,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"guessop": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
"learnop": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"maze": {
"model": "37M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
},
"picoclvr": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"rpl": {
"model": "352M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"snake": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"twotargets": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"memory": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 25000,
"nb_test_samples": 10000,
},
"mixing": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 5,
"nb_train_samples": 60000,
"nb_test_samples": 10000,
},
"dim_keys": 32,
"dim_hidden": 32,
"nb_heads": 2,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"17K-C": {
"nb_heads": 2,
"nb_lines": 16,
"caterpillar_height": 4,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"4M": {
"dim_keys": 32,
"dim_hidden": 1024,
"nb_heads": 4,
- "dim_rec_v": 64,
"nb_blocks": 6,
},
"4M-C": {
"nb_heads": 4,
"nb_lines": 32,
"caterpillar_height": 4,
- "dim_rec_v": 64, # dim_model / nb_heads
"nb_blocks": 6,
},
"37M": {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"37M-C": {
"nb_heads": 8,
"nb_lines": 256,
"caterpillar_height": 32,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"122M": {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"122M-C": {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"352M": {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
"352M-C": {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
}
try:
os.mkdir(args.result_dir)
except FileExistsError:
- if not args.overwrite_results:
+ if not args.continue_training:
print(f"result directory {args.result_dir} already exists")
exit(1)
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+lambda_file = open(os.path.join(args.result_dir, "lambda.dat"), "a")
+
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
+for k, v in sup_args.items():
+ log_string(f'sup_args["{k}"] "{v}"')
+
######################################################################
if it < args.nb_warmup_iter:
return args.legacy_large_lr * it / args.nb_warmup_iter
- elif it < args.legacy_nb_epoch_large_lr:
+ elif n_epoch < args.legacy_nb_epoch_large_lr:
return args.legacy_large_lr
else:
return args.legacy_small_lr
######################################################################
+def add_memex_v1(batches, memex_proba, marker_token):
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
+
+ yield input
+
+
+# The marker token is not used for this one
+def add_memex_v2(batches, memex_proba, marker_token):
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = torch.arange(input.size(1) // 4, device=input.device)[None, :].expand(
+ input.size(0), -1
+ )
+ t = t + torch.randint(
+ input.size(1) - t.size(1), (t.size(0), 1), device=t.device
+ )
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ flash = input[n, t]
+ new_input = torch.cat([input, flash], dim=1)
+
+ memex_mask = new_input.new_zeros(new_input.size())
+ memex_mask[:, input.size(1) :] = 1.0
+
+ yield new_input, memex_mask
+
+ else:
+ yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
+ for input in batches:
+ memex_len = input.size(1) // 8
+
+ t = torch.arange(input.size(1) + memex_len, device=input.device)[
+ None, :
+ ].expand(input.size(0), -1)
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ t = (t - 1).clamp(min=0)
+
+ # Call me the tensor-spaghetti master
+
+ trigger = torch.rand(t.size(), device=t.device)
+ trigger[:, -memex_len:] = 2.0
+ trigger[:, : memex_len + 1] = 2.0
+ trigger = (trigger == trigger.min(dim=1, keepdim=True).values).long()
+ memex_mask = trigger.clone()
+ memex_mask[:, memex_len:] -= trigger[:, :-memex_len]
+ memex_mask = memex_mask.cumsum(dim=1)
+
+ u = 1 - memex_mask
+ u[:, 0] = 0
+ u = u.cumsum(dim=1)
+
+ v = (
+ (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
+ + torch.randint(
+ input.size(1) - memex_len, (input.size(0), 1), device=t.device
+ )
+ ) * memex_mask
+ u = u * (1 - memex_mask) + v * memex_mask
+
+ new_input = input[n, u]
+ limits = trigger.clone()
+ limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
+ new_input = new_input * (1 - limits) + marker_token * limits
+ new_input[:, 0] = marker_token
+
+ orig = torch.cat(
+ [
+ input,
+ torch.full((input.size(0), memex_len), memex_marker, device=t.device),
+ ],
+ dim=1,
+ )
+
+ a = (torch.rand(input.size(0), 1, device=t.device) <= memex_proba).long()
+
+ new_input = (1 - a) * orig + a * new_input
+
+ yield new_input # memex_mask
+
+
+######################################################################
+
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+assert args.batch_size % args.physical_batch_size == 0
+
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
problem=problems.ProblemByHeart(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemLearnOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemGuessOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemTwoTargets(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemMemory(len_total=args.memory_len_total),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemAddition(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.picoclvr_height,
width=args.picoclvr_width,
nb_colors=args.picoclvr_nb_colors,
task = tasks.MNIST(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.Maze(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.maze_height,
width=args.maze_width,
nb_walls=args.maze_nb_walls,
task = tasks.Snake(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.snake_height,
width=args.snake_width,
nb_colors=args.snake_nb_colors,
task = tasks.Stack(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
nb_steps=args.stack_nb_steps,
nb_stacks=args.stack_nb_stacks,
sequence_length=args.expr_sequence_length,
operand_max=args.expr_operand_max,
result_max=args.expr_result_max,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.RPL(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
nb_starting_values=args.rpl_nb_starting_values,
max_input=args.rpl_max_input,
prog_len=args.rpl_prog_len,
task = tasks.Grid(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
size=args.grid_size,
+ nb_shapes=args.grid_nb_shapes,
+ nb_colors=args.grid_nb_colors,
logger=log_string,
device=device_data,
)
task = tasks.QMLP(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=device_data,
vocabulary_size = task.vocabulary_size()
+if args.memex_proba > 0:
+ memex_marker = vocabulary_size
+ vocabulary_size += 1
+
log_string(f"vocabulary_size {vocabulary_size}")
##############################
nb_heads=args.nb_heads,
nb_lines=args.nb_lines,
caterpillar_height=args.caterpillar_height,
- dim_rec_v=args.dim_rec_v,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
+ logger=log_string,
+ args=args,
)
model.to(device)
##############################
+if "calibrate" in sup_args:
+ for input in task.batches(split="train", desc="calibrate"):
+ input = input.to(device)
+ output = model(mygpt.BracketedSequence(input)).x
+
+ for n, m in model.named_modules():
+ for a in dir(m):
+ x = getattr(m, a)
+ if isinstance(x, mygpt.Calibrator):
+ print(f"####### ${n} | ${a} ########################")
+ mean, std = x.moments()
+ print("mean\n", mean, "\n")
+ print("std\n", std, "\n")
+ print(f"############################################\n\n")
+
+ exit(0)
+
+##############################
+
nb_samples_seen = 0
if nb_epochs_finished >= nb_epochs:
deterministic_synthesis=args.deterministic_synthesis,
)
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
it = 0
+n_batch = 0
+
+
+def the_dot_products(value1, value2, params):
+ g1g1, g1g2, g2g2 = 0, 0, 0
+ for p in params:
+ g1 = torch.autograd.grad(value1, p, retain_graph=True)[0]
+ g2 = torch.autograd.grad(value2, p, retain_graph=True)[0]
+ g1g1 += g1.pow(2).sum()[None]
+ g2g2 += g2.pow(2).sum()[None]
+ g1g2 += (g1 * g2).sum()[None]
+ return torch.cat([g1g1, g1g2, g2g2])
+
+
+def update_ave_grad(value, params, name, eps=1e-3):
+ for p in params:
+ g = torch.autograd.grad(value, p, retain_graph=True)[0]
+ ag = getattr(p, name) if hasattr(p, name) else 0
+ setattr(p, name, (1 - eps) * ag + eps * g)
+
+
+def norm(params, name):
+ s = 0
+ for p in params:
+ s += getattr(p, name).pow(2).sum()
+ return s
+
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- for input in task.batches(split="train"):
- model.reset_inner_loss()
- input = input.to(device)
+ memex_proba = (
+ args.memex_proba
+ if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+ else 0.0
+ )
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- inner_loss = model.get_inner_loss()
+ log_string(f"memex_proba {memex_proba}")
+
+ if args.memex_proba > 0:
+ warnings.warn("memex v3", RuntimeWarning)
+ train_batches = add_memex_v3(
+ batches=task.batches(split="train"),
+ memex_proba=memex_proba,
+ marker_token=memex_marker,
+ )
+ else:
+ train_batches = task.batches(split="train")
- acc_train_loss += loss.item() * input.size(0)
- acc_train_inner_loss += inner_loss.item() * input.size(0)
+ def add_none(it):
+ for x in it:
+ yield x
+ yield None
- nb_train_samples += input.size(0)
- nb_samples_seen += input.size(0)
+ nb_acc_samples = 0
- total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+ for input in add_none(train_batches):
+ if input is not None:
+ if type(input) is tuple:
+ input, memex_mask = input
+ memex_mask = memex_mask.to(device)
+ else:
+ memex_mask = None
- it += 1
- lr = get_lr(n_epoch, it)
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
+ model.reset_inner_loss()
+ input = input.to(device)
+
+ output = model(mygpt.BracketedSequence(input)).x
+
+ if memex_mask is None:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ else:
+ loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ loss_regular = (loss * (1 - memex_mask)).mean()
+ loss_memex = (loss * memex_mask).mean()
+
+ if it < 100 or torch.rand(1) < 0.01:
+ update_ave_grad(loss_regular, model.parameters(), "grad_regular")
+ update_ave_grad(loss_memex, model.parameters(), "grad_memex")
+ norm_regular = norm(model.parameters(), "grad_regular")
+ norm_memex = norm(model.parameters(), "grad_memex")
+ l_memex = (
+ max(norm_regular, norm_memex) - norm_regular
+ ) / norm_memex
+
+ loss = loss_regular + l_memex * loss_memex
+
+ inner_loss = model.get_inner_loss()
- # log_string(f"learning_rate {lr}")
+ acc_train_loss += loss.item() * input.size(0)
+ acc_train_inner_loss += inner_loss.item() * input.size(0)
- optimizer.zero_grad()
- total_loss.backward()
- optimizer.step()
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
+
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
+
+ it += 1
+ lr = get_lr(n_epoch, it)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+ # log_string(f"learning_rate {lr}")
+
+ total_loss.backward()
+ nb_acc_samples += input.size(0)
+
+ if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+ assert nb_acc_samples <= args.batch_size
+ optimizer.step()
+ grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ if memex_mask is not None:
+ lambda_file.write(
+ f"{n_epoch} {n_batch} {l_memex} {norm_regular} {norm_memex}\n"
+ )
+ optimizer.zero_grad()
+ nb_acc_samples = 0
+ n_batch += 1
with torch.autograd.no_grad():
model.eval()
)
time_current_result = datetime.datetime.now()
- if time_pred_result is not None:
- log_string(
- f"next_result {time_current_result + (time_current_result - time_pred_result)}"
- )
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
time_pred_result = time_current_result
checkpoint = {