from torch import nn
from torch.nn import functional as F
+# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
import ffutils
import mygpt, tasks, problems
########################################
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
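+# Forward/backward passes are computed on physical_batch_size samples at
+# a time; gradients are accumulated and the optimizer steps once
+# batch_size samples have been seen (batch_size is expected to be a
+# multiple of physical_batch_size).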
+parser.add_argument("--physical_batch_size", type=int, default=None)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=None)
parser.add_argument("--memex_proba", type=float, default=0)
-parser.add_argument("--memex_nb_epochs", type=float, default=1)
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
parser.add_argument("--dim_model", type=int, default=None)
default_task_args = {
"addition": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"byheart": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"expr": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"grid": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"qmlp": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 10,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"guessop": {
"model": "352M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
"learnop": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"maze": {
"model": "37M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
},
"picoclvr": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"rpl": {
"model": "352M",
- "batch_size": 5,
+ "physical_batch_size": 5,
"nb_train_samples": 2500000,
"nb_test_samples": 10000,
},
"snake": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"stack": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 100000,
"nb_test_samples": 1000,
},
"twotargets": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
"memory": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 25000,
"nb_test_samples": 10000,
},
"mixing": {
"model": "37M",
- "batch_size": 25,
+ "physical_batch_size": 25,
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
"mnist": {
"model": "37M",
- "batch_size": 10,
+ "physical_batch_size": 5,
"nb_train_samples": 60000,
"nb_test_samples": 10000,
},
######################################################################
+def add_memex_v2(batches, memex_proba, marker_token):
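+    # With probability memex_proba, yield before a batch an additional
+    # "memex" batch of width 1 + 2*T built from it: the first u0 tokens
+    # of each sequence, then marker tokens, then a full copy of the
+    # sequence starting at a random position u1, then marker tokens up
+    # to the end. The original batch is always yielded.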
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(1 + 2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ yield new_input
+
+ yield input
+
+
+def add_memex_v3(batches, memex_proba, marker_token):
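+    # Variant of add_memex_v2 over 2*T positions; the per-position
+    # fragment mask u computed below is not used by the re-indexing,
+    # which follows add_memex_v2.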
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ t = (
+ torch.arange(2 * input.size(1), device=input.device)[None, :]
+ .expand(input.size(0), -1)
+ .clone()
+ )
+
+ u = torch.rand(t.size(), device=t.device)
+ u[:, : input.size(1)] = 1.0
+ memex_v3_proba_fragment = 1 / 20
+ u = (u < memex_v3_proba_fragment).long()
+            v = u * torch.randint(input.size(1), u.size(), device=u.device)
+ u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[
+ :, : input.size(1) - 1
+ ] * input.size(1)
+            u = u.cumsum(dim=1).clamp(min=0)
+
+ u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device)
+ caterpillar_length = args.nb_lines // args.caterpillar_height
+ u1 = (
+ u0
+ + torch.randint(
+ caterpillar_length, (input.size(0), 1), device=input.device
+ )
+ + 1
+ )
+
+ m0 = (t < u0).long()
+ m1 = (t >= u1).long() * (t < u1 + input.size(1)).long()
+
+ t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
+ m = (t < 0).long()
+ n = torch.arange(input.size(0), device=input.device)[:, None].expand(
+ -1, t.size(1)
+ )
+
+ new_input = input[n, t.clamp(min=0)]
+ new_input = (1 - m) * new_input + m * (marker_token)
+
+ yield new_input
+
+ yield input
+
+
+######################################################################
+
assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
problem=problems.ProblemByHeart(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemLearnOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemGuessOperator(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemTwoTargets(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemMemory(len_total=args.memory_len_total),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
problem=problems.ProblemAddition(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
device=device_data,
)
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.picoclvr_height,
width=args.picoclvr_width,
nb_colors=args.picoclvr_nb_colors,
task = tasks.MNIST(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.Maze(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.maze_height,
width=args.maze_width,
nb_walls=args.maze_nb_walls,
task = tasks.Snake(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
height=args.snake_height,
width=args.snake_width,
nb_colors=args.snake_nb_colors,
task = tasks.Stack(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
logger=log_string,
nb_steps=args.stack_nb_steps,
nb_stacks=args.stack_nb_stacks,
sequence_length=args.expr_sequence_length,
operand_max=args.expr_operand_max,
result_max=args.expr_result_max,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
device=device_data,
)
task = tasks.RPL(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
nb_starting_values=args.rpl_nb_starting_values,
max_input=args.rpl_max_input,
prog_len=args.rpl_prog_len,
task = tasks.Grid(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
size=args.grid_size,
nb_shapes=args.grid_nb_shapes,
nb_colors=args.grid_nb_colors,
task = tasks.QMLP(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
+ batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=device_data,
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- def add_memex(batches, memex_proba):
- for input in batches:
- if torch.rand(1).item() < memex_proba:
- sep = torch.full(
- (input.size(0), 1), vocabulary_size - 1, device=input.device
- )
+ memex_proba = (
+ args.memex_proba
+ if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+ else 0.0
+ )
- yield torch.cat(
- [
- input,
- sep,
- input,
- ],
- dim=1,
- )
- yield input
+ log_string(f"memex_proba {memex_proba}")
- train_batches = add_memex(
- task.batches(split="train"),
- args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+ train_batches = add_memex_v2(
+ batches=task.batches(split="train"),
+ memex_proba=memex_proba,
+ marker_token=vocabulary_size - 1,
)
- for input in train_batches:
- model.reset_inner_loss()
- input = input.to(device)
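+    # Append a trailing None to the batch stream so that the last,
+    # possibly partial, gradient-accumulation window is flushed below.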
+ def add_none(it):
+ for x in it:
+ yield x
+        yield None
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- inner_loss = model.get_inner_loss()
+ nb_acc_samples = 0
- acc_train_loss += loss.item() * input.size(0)
- acc_train_inner_loss += inner_loss.item() * input.size(0)
+ for input in add_none(train_batches):
+ if input is not None:
+ model.reset_inner_loss()
+ input = input.to(device)
- nb_train_samples += input.size(0)
- nb_samples_seen += input.size(0)
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ inner_loss = model.get_inner_loss()
- total_loss = loss + (
- args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
- )
+ acc_train_loss += loss.item() * input.size(0)
+ acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+ nb_samples_seen += input.size(0)
- it += 1
- lr = get_lr(n_epoch, it)
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
- # log_string(f"learning_rate {lr}")
+ it += 1
+ lr = get_lr(n_epoch, it)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
- optimizer.zero_grad()
- total_loss.backward()
- optimizer.step()
+ # log_string(f"learning_rate {lr}")
- grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+ total_loss.backward()
+ nb_acc_samples += input.size(0)
- loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
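+        # Take an optimizer step once args.batch_size samples have been
+        # accumulated, or when the trailing None flushes the last window.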
+ if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+ assert nb_acc_samples <= args.batch_size
+ optimizer.step()
+ grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+ optimizer.zero_grad()
+ nb_acc_samples = 0
n_batch += 1
##############################
+class NaNChecker(nn.Module):
+ def __init__(self, name):
+ super().__init__()
+ self.name = name
+
+ def forward(self, bs):
+ x = bs.x if type(bs) is BracketedSequence else bs
+        assert not x.isnan().any(), f"{self.name} detected NaN"
+        assert not x.isinf().any(), f"{self.name} detected Inf"
+ return bs
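+
+# NaNChecker passes its input through unchanged, so instances can be
+# interleaved between blocks, e.g.
+# nn.Sequential(block, NaNChecker("block 3"), ...), to pinpoint the
+# first layer that produces NaN / Inf values.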
+
+
class WithResidual(nn.Module):
def __init__(self, *f):
super().__init__()
self.w_qw = randw(nb_heads, dim_qk, dim_model)
self.w_qr = randw(nb_heads, dim_qk, dim_model)
- # self.w_k = randw(nb_heads, dim_qk, dim_model)
self.w_v = randw(nb_heads, dim_v, dim_model)
self.w_o = randw(dim_v * nb_heads, dim_model)
- def reset_inner_loss(self):
- self.acc_attention = 0
- self.acc_nb = 0
-
- def get_inner_loss(self):
- warnings.warn("l2 regularization", RuntimeWarning)
- return (self.acc_attention / self.acc_nb).pow(2).sum()
- # return torch.tensor([0], device=self.w_qw.device)
-
def forward(self, bs):
x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb
self.rec_v = x_q.new_zeros(
x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1)
)
- # self.rec_k = x_q.new_zeros(
- # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1)
- # )
self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
- ######################################################################
- # Prepare the keys
-
- k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
-
- warnings.warn("rotating key barrel", RuntimeWarning)
- k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
- t_barrel = torch.arange(t0, t1, device=k_star.device)
- t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
- l_barrel = (
- torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
- ) % k_star.size(0)
- k_star = k_star[l_barrel, t_barrel]
-
######################################################################
# Compute the recurrent state
qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw)
v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v)
- # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k)
- aw = torch.einsum(
- "nhtd,ltd->nhlt",
- qw,
- k_star,
- ) / math.sqrt(self.w_qw.size(1))
+ aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt(
+ self.w_qw.size(1)
+ )
aw = aw.softmax(dim=2) # nhlt
- if self.train:
- self.acc_attention += aw.sum(dim=(0, 1, 3))
- self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3)
-
aw = F.dropout(aw, self.attention_dropout, self.training)
A = 1 - aw.sum(dim=1) # nlt
V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous()
- # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous()
if t0 == 0:
V0 = None
- # K0 = None
else:
V0 = self.rec_v[:, :, t0 - 1]
- # K0 = self.rec_k[:, :, t0 - 1]
self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0)
- # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0)
######################################################################
# compute the readout
ar = torch.einsum(
"nhtd,ld->nhlt",
qr,
- # self.rec_k[:, :, t0:t1],
self.k_star,
) / math.sqrt(self.w_qr.size(1))
self.acc_nb = 0
def get_inner_loss(self):
- warnings.warn("l2 regularization", RuntimeWarning)
- return (self.acc_attention / self.acc_nb).pow(2).sum()
- # return torch.tensor([0], device=self.w_qw.device)
+ # warnings.warn("l2 regularization", RuntimeWarning)
+ # return (self.acc_attention / self.acc_nb).pow(2).sum()
+ return torch.tensor([0], device=self.w_qw.device)
# warnings.warn("side regularization", RuntimeWarning)
# return (
# (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum()
k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1)
- warnings.warn("rotating key barrel", RuntimeWarning)
+ # warnings.warn("rotating key barrel", RuntimeWarning)
k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1)
t_barrel = torch.arange(t0, t1, device=k_star.device)
t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0)
l_barrel = (
- torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel
+ torch.arange(k_star.size(0), device=k_star.device)[:, None] # + t_barrel
) % k_star.size(0)
k_star = k_star[l_barrel, t_barrel]
):
super().__init__()
+ self.vocabulary_size = vocabulary_size
+
assert attention_layer in {
"mha",
"dumbrec",
)
for t in range(nb_steps):
- op = torch.randint(2, (nb,))
- st = torch.randint(nb_stacks, (nb,))
- op = op * (stack_counts[k, st] > 0)
- if values is None:
+ op = torch.randint(2, (nb,)) # what operation (push/pop)
+ st = torch.randint(nb_stacks, (nb,)) # on what stack
+        op = op * (stack_counts[k, st] > 0) # can only pop if the stack is not empty
+
+ if values is None: # we can use all the values
val_push = torch.randint(10**nb_digits, (nb,))
- else:
+ else: # values are constrained (e.g. to have train/test values disjoint)
val_push = values[torch.randint(values.size(0), (nb,))]
- val_pop = stack[
+
+ val_pop = stack[ # if we were popping, what value would that be?
k,
st,
- (stack_counts[k, st] - 1).clamp(min=0),
+ (stack_counts[k, st] - 1).clamp(min=0), # deal with empty stack
]
+
+ # we always push the value, but it will be lost if we pop
+ # since we will move the count down
stack[k, st, stack_counts[k, st]] = val_push
recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
+
+        # increase the stack count on a push, decrease it on a pop
stack_counts[k[op == 0], st[op == 0]] += 1
stack_counts[k[op == 1], st[op == 1]] -= 1
+
+        # add the operation code to the sequence, it encodes both the stack index and push/pop
result[:, (1 + nb_digits) * t] = st * 2 + op
+
+ # add the digits to the sequence
for d in range(nb_digits):
result[:, (1 + nb_digits) * t + 1 + d] = (
(op * val_pop + (1 - op) * val_push) // (10**d)
seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
- assert seq.size(0) % (1 + nb_digits) == 0
- s = ""
- for t in range(seq.size(0) // (1 + nb_digits)):
- n_op = seq[(1 + nb_digits) * t]
- if t > 0:
- s += " "
- if recorded_stack_counts is not None:
- s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
- s += f"POP" if n_op % 2 == 1 else f"PSH"
- if nb_stacks > 1:
- s += f"_{n_op//2}"
- for d in range(nb_digits):
- if seq[(1 + nb_digits) * t + 1 + d] == -1:
- s += " ?"
- else:
- s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
- return s
+def seq_to_str(seq, nb_stacks, nb_digits):
+ def n_to_str(n):
+ if n < 0:
+ return "?"
+ elif n < 2 * nb_stacks:
+ s = f"POP" if n % 2 == 1 else f"PSH"
+ if nb_stacks > 1:
+ s += f"_{n//2}"
+ return s
+ elif n < 2 * nb_stacks + 10:
+ return f"{n - 2 * nb_stacks}"
+ else:
+ return "#"
+
+ return " ".join([n_to_str(x.item()) for x in seq])
######################################################################
if __name__ == "__main__":
+ seq, recorded_stack_counts = generate_sequences(
+ nb=3,
+ nb_steps=6,
+ nb_stacks=3,
+ nb_digits=3,
+ )
+
+ sep = torch.full((seq.size(0), 1), seq.max() + 1)
+
+ seq = torch.cat([seq, sep, seq], dim=1)
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ remove_popped_values(seq, 3, 3)
+
+ print()
+
+ for n in range(min(10, seq.size(0))):
+ print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3))
+
+ exit(0)
+
nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
seq, recorded_stack_counts = generate_sequences(
nb=nb,
print("-- PREPARED FOR TEST -----------------")
+ print("SANITY", seq.size())
+
remove_popped_values(seq, nb_stacks, nb_digits)
for n in range(min(10, seq.size(0))):