from torch import nn
from torch.nn import functional as F
+# torch.autograd.set_detect_anomaly(True)  # debug only, slows training considerably
+
import ffutils
import mygpt, tasks, problems
########################################
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
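+# --batch_size is the number of samples consumed per optimizer step, while
+# --physical_batch_size is the number processed per forward/backward pass;
+# gradients are accumulated across physical batches (see the training loop).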
+parser.add_argument("--physical_batch_size", type=int, default=None)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=None)
parser.add_argument("--memex_proba", type=float, default=0)
-parser.add_argument("--memex_nb_epochs", type=float, default=1)
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
parser.add_argument("--dim_model", type=int, default=None)
default_task_args = {
"addition": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"byheart": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"expr": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"grid": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"qmlp": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 10,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"guessop": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
"learnop": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"maze": {
"model": "37M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
},
"picoclvr": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"rpl": {
"model": "352M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"snake": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"twotargets": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"memory": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 25000,
"nb_test_samples": 10000,
},
"mixing": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 5,
"nb_train_samples": 60000,
"nb_test_samples": 10000,
},
######################################################################
+def add_memex_v2(batches, memex_proba, marker_token):
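+    # With probability memex_proba, yield an extra "memex" batch ahead of
+    # the original one. Each row, of length 1 + 2 * L for a length-L
+    # sequence, is: the prefix up to a random position u0, marker tokens
+    # up to u1, then the sequence again (cut at the right edge if it does
+    # not fit), with markers filling the remainder. E.g. with L = 5,
+    # u0 = 2, u1 = 4: x0 x1 M M x0 x1 x2 x3 x4 M M.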
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ yield new_input
+
+ yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
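+    # Like add_memex_v2 but over length 2 * L. The fragment mask u derived
+    # below from memex_v3_proba_fragment is computed but not yet used to
+    # assemble the output, which follows the same prefix / marker-gap /
+    # copy construction as v2.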
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u = torch.rand(t.size(), device=t.device)
+ u[:, : input.size(1)] = 1.0
+ memex_v3_proba_fragment = 1 / 20
+ u = (u < memex_v3_proba_fragment).long()
+            v = u * torch.randint(input.size(1), u.size(), device=u.device)
+ u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
+ :, : input.size(1) - 1
+ ] * input.size(1)
+            u = u.cumsum(dim=1).clamp(min=0)
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ yield new_input
+
+ yield input
+
+
+######################################################################
+
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
problem=problems.ProblemByHeart(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemLearnOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemGuessOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemTwoTargets(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemMemory(len_total=args.memory_len_total),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemAddition(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.picoclvr_height,
width=args.picoclvr_width,
nb_colors=args.picoclvr_nb_colors,
task = tasks.MNIST(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.Maze(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.maze_height,
width=args.maze_width,
nb_walls=args.maze_nb_walls,
task = tasks.Snake(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.snake_height,
width=args.snake_width,
nb_colors=args.snake_nb_colors,
task = tasks.Stack(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
nb_steps=args.stack_nb_steps,
nb_stacks=args.stack_nb_stacks,
sequence_length=args.expr_sequence_length,
operand_max=args.expr_operand_max,
result_max=args.expr_result_max,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.RPL(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
nb_starting_values=args.rpl_nb_starting_values,
max_input=args.rpl_max_input,
prog_len=args.rpl_prog_len,
task = tasks.Grid(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
size=args.grid_size,
nb_shapes=args.grid_nb_shapes,
nb_colors=args.grid_nb_colors,
task = tasks.QMLP(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=device_data,
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- def add_memex(batches, memex_proba):
- for input in batches:
- if torch.rand(1).item() < memex_proba:
- sep = torch.full(
- (input.size(0), 1), vocabulary_size - 1, device=input.device
- )
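+    # Memex batches are drawn with probability args.memex_proba, but only
+    # during the first args.memex_nb_epochs epochs when that limit is set.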
+ memex_proba = (
+ args.memex_proba
+ if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+ else 0.0
+ )
- yield torch.cat(
- [
- input,
- sep,
- input,
- ],
- dim=1,
- )
- yield input
+ log_string(f"memex_proba {memex_proba}")
- train_batches = add_memex(
- task.batches(split="train"),
- args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+ train_batches = add_memex_v2(
+ batches=task.batches(split="train"),
+ memex_proba=memex_proba,
+ marker_token=vocabulary_size - 1,
)
- for input in train_batches:
- model.reset_inner_loss()
- input = input.to(device)
+    # Append a None sentinel so the training loop can flush the last,
+    # possibly incomplete, gradient-accumulation group at the end of
+    # the epoch.
+    def add_none(it):
+        yield from it
+        yield None
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- inner_loss = model.get_inner_loss()
+ nb_acc_samples = 0
- acc_train_loss += loss.item() * input.size(0)
- acc_train_inner_loss += inner_loss.item() * input.size(0)
+ for input in add_none(train_batches):
+ if input is not None:
+ model.reset_inner_loss()
+ input = input.to(device)
- nb_train_samples += input.size(0)
- nb_samples_seen += input.size(0)
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ inner_loss = model.get_inner_loss()
- total_loss = loss + (
- args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
- )
+ acc_train_loss += loss.item() * input.size(0)
+ acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
- it += 1
- lr = get_lr(n_epoch, it)
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
- # log_string(f"learning_rate {lr}")
+ it += 1
+ lr = get_lr(n_epoch, it)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
- optimizer.zero_grad()
- total_loss.backward()
- optimizer.step()
+ # log_string(f"learning_rate {lr}")
- grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+ total_loss.backward()
+ nb_acc_samples += input.size(0)
- loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
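+        # One optimizer step per logical batch: step when nb_acc_samples
+        # reaches args.batch_size (this assumes args.batch_size is a
+        # multiple of args.physical_batch_size) or when the epoch ends
+        # (input is None). Gradients are summed, not averaged, over the
+        # accumulated micro-batches.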
+ if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+ assert nb_acc_samples <= args.batch_size
+ optimizer.step()
+            grad_norm = sum(
+                p.grad.pow(2).sum() for p in model.parameters() if p.grad is not None
+            ).sqrt()
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ optimizer.zero_grad()
+ nb_acc_samples = 0
n_batch += 1