From 8012a611e9920816fe6ba382b69305242136bc2a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 15 Feb 2024 23:10:17 +0100 Subject: [PATCH] Update. --- fridge | 19 +++++ main.py | 243 ++++++++++++++++++++++++++++++++++++++----------------- maze.py | 9 ++- mygpt.py | 69 +++++----------- stack.py | 83 +++++++++++++------ tasks.py | 29 +++++-- 6 files changed, 295 insertions(+), 157 deletions(-) diff --git a/fridge b/fridge index 82d2b17..143092c 100644 --- a/fridge +++ b/fridge @@ -316,3 +316,22 @@ class Calibrator: if isinstance(m, mygpt.Caterpillar): + +###################################################################### + +2024 Feb 13 22:53:52 (from mygpt.py) + + ###################################################################### + # Prepare the keys + + k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) + + warnings.warn("rotating key barrel", RuntimeWarning) + k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) + t_barrel = torch.arange(t0, t1, device=k_star.device) + t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) + l_barrel = ( + torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + ) % k_star.size(0) + k_star = k_star[l_barrel, t_barrel] + diff --git a/main.py b/main.py index d6845e8..6254807 100755 --- a/main.py +++ b/main.py @@ -11,6 +11,8 @@ import torch, torchvision from torch import nn from torch.nn import functional as F +# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + import ffutils import mygpt, tasks, problems @@ -51,9 +53,11 @@ parser.add_argument("--force_cpu", type=str2bool, default=False) ######################################## -parser.add_argument("--nb_epochs", type=int, default=50) +parser.add_argument("--nb_epochs", type=int, default=25) + +parser.add_argument("--physical_batch_size", type=int, default=None) -parser.add_argument("--batch_size", type=int, default=None) +parser.add_argument("--batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=None) @@ -89,7 +93,7 @@ parser.add_argument("--attention", type=str, default=None) parser.add_argument("--memex_proba", type=float, default=0) -parser.add_argument("--memex_nb_epochs", type=float, default=1) +parser.add_argument("--memex_nb_epochs", type=float, default=None) parser.add_argument("--dim_model", type=int, default=None) @@ -238,97 +242,97 @@ else: default_task_args = { "addition": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "byheart": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "expr": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 2500000, "nb_test_samples": 10000, }, "grid": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "qmlp": { "model": "37M", - "batch_size": 10, + "physical_batch_size": 10, "nb_train_samples": 100000, "nb_test_samples": 1000, }, "guessop": { "model": "352M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 1000000, "nb_test_samples": 10000, }, "learnop": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "maze": { "model": "37M", - "batch_size": 5, + "physical_batch_size": 5, "nb_train_samples": 100000, "nb_test_samples": 10000, }, "picoclvr": { "model": "37M", - "batch_size": 
25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "rpl": { "model": "352M", - "batch_size": 5, + "physical_batch_size": 5, "nb_train_samples": 2500000, "nb_test_samples": 10000, }, "snake": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "stack": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 100000, "nb_test_samples": 1000, }, "twotargets": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "memory": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 25000, "nb_test_samples": 10000, }, "mixing": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { "model": "37M", - "batch_size": 10, + "physical_batch_size": 5, "nb_train_samples": 60000, "nb_test_samples": 10000, }, @@ -526,6 +530,90 @@ def get_lr(n_epoch, it): ###################################################################### +def add_memex_v2(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(1 + 2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + +def add_memex_v3(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u = torch.rand(t.size(), device=t.device) + u[:, : input.size(1)] = 1.0 + memex_v3_proba_fragment = 1 / 20 + u = (u < memex_v3_proba_fragment).long() + v = u * torch.randint(input.size(1), u.size()) + u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ + :, : input.size(1) - 1 + ] * input.size(1) + u = u.cumsum().clamp(min=0) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + +###################################################################### + assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} @@ -554,7 +642,7 @@ if args.task == "byheart": problem=problems.ProblemByHeart(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - 
batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -565,7 +653,7 @@ elif args.task == "learnop": problem=problems.ProblemLearnOperator(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -576,7 +664,7 @@ elif args.task == "guessop": problem=problems.ProblemGuessOperator(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -587,7 +675,7 @@ elif args.task == "twotargets": problem=problems.ProblemTwoTargets(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -597,7 +685,7 @@ elif args.task == "memory": problem=problems.ProblemMemory(len_total=args.memory_len_total), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -609,7 +697,7 @@ elif args.task == "mixing": ), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -619,7 +707,7 @@ elif args.task == "addition": problem=problems.ProblemAddition(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -628,7 +716,7 @@ elif args.task == "picoclvr": task = tasks.PicoCLVR( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.picoclvr_height, width=args.picoclvr_width, nb_colors=args.picoclvr_nb_colors, @@ -642,7 +730,7 @@ elif args.task == "mnist": task = tasks.MNIST( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, device=device_data, ) @@ -650,7 +738,7 @@ elif args.task == "maze": task = tasks.Maze( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.maze_height, width=args.maze_width, nb_walls=args.maze_nb_walls, @@ -661,7 +749,7 @@ elif args.task == "snake": task = tasks.Snake( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, height=args.snake_height, width=args.snake_width, nb_colors=args.snake_nb_colors, @@ -674,7 +762,7 @@ elif args.task == "stack": task = tasks.Stack( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, nb_steps=args.stack_nb_steps, nb_stacks=args.stack_nb_stacks, @@ -691,7 +779,7 @@ elif args.task == "expr": sequence_length=args.expr_sequence_length, operand_max=args.expr_operand_max, result_max=args.expr_result_max, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, device=device_data, ) @@ -699,7 +787,7 @@ elif args.task == "rpl": task = tasks.RPL( nb_train_samples=args.nb_train_samples, 
nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, nb_starting_values=args.rpl_nb_starting_values, max_input=args.rpl_max_input, prog_len=args.rpl_prog_len, @@ -713,7 +801,7 @@ elif args.task == "grid": task = tasks.Grid( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, size=args.grid_size, nb_shapes=args.grid_nb_shapes, nb_colors=args.grid_nb_colors, @@ -725,7 +813,7 @@ elif args.task == "qmlp": task = tasks.QMLP( nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, device=device_data, @@ -904,60 +992,63 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - def add_memex(batches, memex_proba): - for input in batches: - if torch.rand(1).item() < memex_proba: - sep = torch.full( - (input.size(0), 1), vocabulary_size - 1, device=input.device - ) + memex_proba = ( + args.memex_proba + if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs + else 0.0 + ) - yield torch.cat( - [ - input, - sep, - input, - ], - dim=1, - ) - yield input + log_string(f"memex_proba {memex_proba}") - train_batches = add_memex( - task.batches(split="train"), - args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0, + train_batches = add_memex_v2( + batches=task.batches(split="train"), + memex_proba=memex_proba, + marker_token=vocabulary_size - 1, ) - for input in train_batches: - model.reset_inner_loss() - input = input.to(device) + def add_none(it): + for x in it: + yield x + yield None - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) - inner_loss = model.get_inner_loss() + nb_acc_samples = 0 - acc_train_loss += loss.item() * input.size(0) - acc_train_inner_loss += inner_loss.item() * input.size(0) + for input in add_none(train_batches): + if input is not None: + model.reset_inner_loss() + input = input.to(device) - nb_train_samples += input.size(0) - nb_samples_seen += input.size(0) + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + inner_loss = model.get_inner_loss() - total_loss = loss + ( - args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 - ) + acc_train_loss += loss.item() * input.size(0) + acc_train_inner_loss += inner_loss.item() * input.size(0) + + nb_train_samples += input.size(0) + nb_samples_seen += input.size(0) - it += 1 - lr = get_lr(n_epoch, it) - for param_group in optimizer.param_groups: - param_group["lr"] = lr + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) - # log_string(f"learning_rate {lr}") + it += 1 + lr = get_lr(n_epoch, it) + for param_group in optimizer.param_groups: + param_group["lr"] = lr - optimizer.zero_grad() - total_loss.backward() - optimizer.step() + # log_string(f"learning_rate {lr}") - grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() + total_loss.backward() + nb_acc_samples += input.size(0) - loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size: + assert nb_acc_samples <= args.batch_size + optimizer.step() + grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() + 
loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + optimizer.zero_grad() + nb_acc_samples = 0 n_batch += 1 diff --git a/maze.py b/maze.py index 8ac9fce..4953d10 100755 --- a/maze.py +++ b/maze.py @@ -231,9 +231,14 @@ def save_image( [0, 255, 0], # start [127, 127, 255], # goal [255, 0, 0], # path + [128, 128, 128], # error ] ) + def safe_colors(x): + m = (x >= 0).long() * (x < colors.size(0) - 1).long() + return colors[x * m + (colors.size(0) - 1) * (1 - m)] + mazes = mazes.cpu() c_mazes = ( @@ -256,7 +261,7 @@ def save_image( if predicted_paths is not None: predicted_paths = predicted_paths.cpu() c_predicted_paths = ( - colors[predicted_paths.reshape(-1)] + safe_colors(predicted_paths.reshape(-1)) .reshape(predicted_paths.size() + (-1,)) .permute(0, 3, 1, 2) ) @@ -282,8 +287,6 @@ def save_image( -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4)) ).clone() - print(f"{img.size()=} {imgs.size()=}") - for k in range(imgs.size(1)): img[ :, diff --git a/mygpt.py b/mygpt.py index c833012..12b3631 100755 --- a/mygpt.py +++ b/mygpt.py @@ -86,6 +86,18 @@ class CacheWrapper(nn.Module): ############################## +class NaNChecker(nn.Module): + def __init__(self, name): + super().__init__() + self.name = name + + def forward(self, bs): + x = bs.x if type(bs) is BracketedSequence else bs + assert not x.isnan().any(), f"${self.name} detected NaN" + assert not x.isinf().any(), f"${self.name} detected Inf" + return bs + + class WithResidual(nn.Module): def __init__(self, *f): super().__init__() @@ -218,19 +230,9 @@ class DumbRec(nn.Module): self.w_qw = randw(nb_heads, dim_qk, dim_model) self.w_qr = randw(nb_heads, dim_qk, dim_model) - # self.w_k = randw(nb_heads, dim_qk, dim_model) self.w_v = randw(nb_heads, dim_v, dim_model) self.w_o = randw(dim_v * nb_heads, dim_model) - def reset_inner_loss(self): - self.acc_attention = 0 - self.acc_nb = 0 - - def get_inner_loss(self): - warnings.warn("l2 regularization", RuntimeWarning) - return (self.acc_attention / self.acc_nb).pow(2).sum() - # return torch.tensor([0], device=self.w_qw.device) - def forward(self, bs): x_q, t0, t1 = bs.x, bs.first, bs.first + bs.nb @@ -238,61 +240,33 @@ class DumbRec(nn.Module): self.rec_v = x_q.new_zeros( x_q.size(0), self.nb_lines, x_q.size(1), self.w_v.size(1) ) - # self.rec_k = x_q.new_zeros( - # x_q.size(0), self.nb_lines, x_q.size(1), self.w_k.size(1) - # ) self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) - ###################################################################### - # Prepare the keys - - k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) - - warnings.warn("rotating key barrel", RuntimeWarning) - k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) - t_barrel = torch.arange(t0, t1, device=k_star.device) - t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) - l_barrel = ( - torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel - ) % k_star.size(0) - k_star = k_star[l_barrel, t_barrel] - ###################################################################### # Compute the recurrent state qw = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_qw) v = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_v) - # k = torch.einsum("ntc,hdc->nhtd", x_q[:, t0:t1], self.w_k) - aw = torch.einsum( - "nhtd,ltd->nhlt", - qw, - k_star, - ) / math.sqrt(self.w_qw.size(1)) + aw = torch.einsum("nhtd,ld->nhlt", qw, self.k_star) / math.sqrt( + self.w_qw.size(1) + ) aw = aw.softmax(dim=2) # nhlt - if self.train: - 
self.acc_attention += aw.sum(dim=(0, 1, 3)) - self.acc_nb += aw.size(0) * aw.size(1) * aw.size(3) - aw = F.dropout(aw, self.attention_dropout, self.training) A = 1 - aw.sum(dim=1) # nlt V = torch.einsum("nhlt,nhtd->nltd", aw, v).contiguous() - # K = torch.einsum("nhlt,nhtd->nltd", aw, k).contiguous() if t0 == 0: V0 = None - # K0 = None else: V0 = self.rec_v[:, :, t0 - 1] - # K0 = self.rec_k[:, :, t0 - 1] self.rec_v[:, :, t0:t1] = pscan_shape(A, V, V0) - # self.rec_k[:, :, t0:t1] = pscan_shape(A, K, K0) ###################################################################### # compute the readout @@ -302,7 +276,6 @@ class DumbRec(nn.Module): ar = torch.einsum( "nhtd,ld->nhlt", qr, - # self.rec_k[:, :, t0:t1], self.k_star, ) / math.sqrt(self.w_qr.size(1)) @@ -358,9 +331,9 @@ class KVRec(nn.Module): self.acc_nb = 0 def get_inner_loss(self): - warnings.warn("l2 regularization", RuntimeWarning) - return (self.acc_attention / self.acc_nb).pow(2).sum() - # return torch.tensor([0], device=self.w_qw.device) + # warnings.warn("l2 regularization", RuntimeWarning) + # return (self.acc_attention / self.acc_nb).pow(2).sum() + return torch.tensor([0], device=self.w_qw.device) # warnings.warn("side regularization", RuntimeWarning) # return ( # (0.5 / self.nb_lines - self.acc_attention / self.acc_nb).clamp(min=0).sum() @@ -384,12 +357,12 @@ class KVRec(nn.Module): k_star = self.k_star[:, None, :].expand(-1, t1 - t0, -1) - warnings.warn("rotating key barrel", RuntimeWarning) + # warnings.warn("rotating key barrel", RuntimeWarning) k_star = self.k_star[:, None, :].expand(-1, x_q.size(1), -1) t_barrel = torch.arange(t0, t1, device=k_star.device) t_barrel = t_barrel[None, :].expand(k_star.size(0), t1 - t0) l_barrel = ( - torch.arange(k_star.size(0), device=k_star.device)[:, None] + t_barrel + torch.arange(k_star.size(0), device=k_star.device)[:, None] # + t_barrel ) % k_star.size(0) k_star = k_star[l_barrel, t_barrel] @@ -781,6 +754,8 @@ class MyGPT(nn.Module): ): super().__init__() + self.vocabulary_size = vocabulary_size + assert attention_layer in { "mha", "dumbrec", diff --git a/stack.py b/stack.py index 543f04e..69a696d 100755 --- a/stack.py +++ b/stack.py @@ -25,23 +25,34 @@ def generate_sequences( ) for t in range(nb_steps): - op = torch.randint(2, (nb,)) - st = torch.randint(nb_stacks, (nb,)) - op = op * (stack_counts[k, st] > 0) - if values is None: + op = torch.randint(2, (nb,)) # what operation (push/pop) + st = torch.randint(nb_stacks, (nb,)) # on what stack + op = op * (stack_counts[k, st] > 0) # can only push is stack is empty + + if values is None: # we can use all the values val_push = torch.randint(10**nb_digits, (nb,)) - else: + else: # values are constrained (e.g. to have train/test values disjoint) val_push = values[torch.randint(values.size(0), (nb,))] - val_pop = stack[ + + val_pop = stack[ # if we were popping, what value would that be? 
k, st, - (stack_counts[k, st] - 1).clamp(min=0), + (stack_counts[k, st] - 1).clamp(min=0), # deal with empty stack ] + + # we always push the value, but it will be lost if we pop + # since we will move the count down stack[k, st, stack_counts[k, st]] = val_push recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st] + + # update the stack counts: up by one on a push, down by one on a pop stack_counts[k[op == 0], st[op == 0]] += 1 stack_counts[k[op == 1], st[op == 1]] -= 1 + + # add the operation token to the sequence, which also encodes the stack number result[:, (1 + nb_digits) * t] = st * 2 + op + + # add the digits to the sequence for d in range(nb_digits): result[:, (1 + nb_digits) * t + 1 + d] = ( (op * val_pop + (1 - op) * val_push) // (10**d) @@ -57,29 +68,49 @@ def remove_popped_values(seq, nb_stacks, nb_digits): seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:] -def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None): - assert seq.size(0) % (1 + nb_digits) == 0 - s = "" - for t in range(seq.size(0) // (1 + nb_digits)): - n_op = seq[(1 + nb_digits) * t] - if t > 0: - s += " " - if recorded_stack_counts is not None: - s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] " - s += f"POP" if n_op % 2 == 1 else f"PSH" - if nb_stacks > 1: - s += f"_{n_op//2}" - for d in range(nb_digits): - if seq[(1 + nb_digits) * t + 1 + d] == -1: - s += " ?" - else: - s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}" - return s +def seq_to_str(seq, nb_stacks, nb_digits): + def n_to_str(n): + if n < 0: + return "?" + elif n < 2 * nb_stacks: + s = f"POP" if n % 2 == 1 else f"PSH" + if nb_stacks > 1: + s += f"_{n//2}" + return s + elif n < 2 * nb_stacks + 10: + return f"{n - 2 * nb_stacks}" + else: + return "#" + + return " ".join([n_to_str(x.item()) for x in seq]) ###################################################################### if __name__ == "__main__": + seq, recorded_stack_counts = generate_sequences( + nb=3, + nb_steps=6, + nb_stacks=3, + nb_digits=3, + ) + + sep = torch.full((seq.size(0), 1), seq.max() + 1) + + seq = torch.cat([seq, sep, seq], dim=1) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3)) + + remove_popped_values(seq, 3, 3) + + print() + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], nb_stacks=3, nb_digits=3)) + + exit(0) + nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1 seq, recorded_stack_counts = generate_sequences( nb=nb, @@ -101,6 +132,8 @@ if __name__ == "__main__": print("-- PREPARED FOR TEST -----------------") + print("SANITY", seq.size()) + remove_popped_values(seq, nb_stacks, nb_digits) for n in range(min(10, seq.size(0))): diff --git a/tasks.py b/tasks.py index 727b196..218ff36 100755 --- a/tasks.py +++ b/tasks.py @@ -250,7 +250,13 @@ class PicoCLVR(Task): # Make a list of strings from a tensor def detensorize(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. 
If z is a tuple, all its @@ -888,7 +894,10 @@ class Stack(Task): def compute_nb_correct(input): result = input.clone() stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) + ar_mask = (result != input).long() + result *= 1 - ar_mask + masked_inplace_autoregression( model, self.batch_size, @@ -923,10 +932,12 @@ class Stack(Task): stack.remove_popped_values(result, self.nb_stacks, self.nb_digits) ar_mask = (result != input).long() - # for n in range(result.size(0)): - # logger( - # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" - # ) + for n in range(result.size(0)): + logger( + f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}" + ) + + result *= 1 - ar_mask masked_inplace_autoregression( model, @@ -1448,7 +1459,13 @@ class Grid(Task): # Make a list of strings from a tensor def tensor2str(self, x): - return [" ".join([self.id2token[t.item()] for t in r]) for r in x] + def id2token(t): + try: + return self.id2token[t.item()] + except KeyError: + return "?" + + return [" ".join([id2token(t) for t in r]) for r in x] # trim all the tensors in the tuple z to remove as much token from # left and right in the first tensor. If z is a tuple, all its -- 2.39.5
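######################################################################

Note on the new batch-size arguments in main.py: --physical_batch_size is what the tasks now stream per forward/backward pass, while --batch_size is the number of samples accumulated before a single optimizer step. A minimal, self-contained sketch of that accumulation scheme, with a toy linear model and random data standing in for the repository's task pipeline (all names below are illustrative):

import torch
from torch import nn
from torch.nn import functional as F

physical_batch_size, batch_size = 5, 25  # physical size divides the logical one, as with the defaults

model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def physical_batches(nb=100):
    for _ in range(nb // physical_batch_size):
        yield torch.randn(physical_batch_size, 16), torch.randint(4, (physical_batch_size,))

nb_acc_samples = 0
optimizer.zero_grad()

for x, y in physical_batches():
    loss = F.cross_entropy(model(x), y)
    loss.backward()  # gradients accumulate across physical batches
    nb_acc_samples += x.size(0)
    if nb_acc_samples == batch_size:  # one optimizer step per logical batch
        optimizer.step()
        optimizer.zero_grad()
        nb_acc_samples = 0

if nb_acc_samples > 0:  # flush a trailing partial batch, as the patched loop does via its None sentinel
    optimizer.step()
    optimizer.zero_grad()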
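######################################################################

Note on add_memex_v2 in main.py: with probability memex_proba a training batch is preceded by a "memory exercise" version of itself, built as a prefix of each sequence, a stretch of marker tokens (the last vocabulary entry), then the whole sequence again, padded with markers to a fixed length of 1 + 2 * L; the gap length is drawn up to args.nb_lines // args.caterpillar_height. A simplified per-sequence illustration of that construction, read off the batched index arithmetic in the patch (the function and argument names here are made up):

import torch

def memex_sample(seq, prefix_length, gap_length, marker_token):
    # prefix of the sequence, a gap of markers, then the full sequence
    # again, padded with markers up to length 1 + 2 * len(seq)
    total = 1 + 2 * seq.size(0)
    out = torch.full((total,), marker_token, dtype=seq.dtype)
    out[:prefix_length] = seq[:prefix_length]
    start = prefix_length + gap_length
    n = min(seq.size(0), total - start)  # the second copy may be truncated on the right
    out[start : start + n] = seq[:n]
    return out

seq = torch.arange(10, 20)
print(memex_sample(seq, prefix_length=4, gap_length=3, marker_token=99))
# tensor([10, 11, 12, 13, 99, 99, 99, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 99, 99, 99, 99])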
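######################################################################

Note on the NaNChecker module added to mygpt.py: it is a pass-through layer that asserts on NaN/Inf activations (as written in the patch, its f-strings carry a stray "$" that will be printed literally). How it might be interleaved into a model to localize where a NaN first appears; the usage below is an assumption, the patch only defines the module, and this simplified version takes plain tensors rather than BracketedSequence:

import torch
from torch import nn

class NaNChecker(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, x):
        assert not x.isnan().any(), f"{self.name} detected NaN"
        assert not x.isinf().any(), f"{self.name} detected Inf"
        return x

model = nn.Sequential(
    nn.Linear(8, 8),
    NaNChecker("after first linear"),
    nn.ReLU(),
    nn.Linear(8, 2),
    NaNChecker("after second linear"),
)

model(torch.randn(4, 8))  # passes silently; a NaN would name the offending stage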
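######################################################################

Note on the Stack evaluation changes in tasks.py: the positions the model must regenerate are now blanked before generation. stack.remove_popped_values marks the digits of popped values, ar_mask = (result != input) picks up exactly those positions, and result *= 1 - ar_mask zeroes them before masked_inplace_autoregression fills them back in. The same masking pattern in isolation, on made-up token ids:

import torch

input = torch.tensor([[5, 7, 9, 4, 8]])  # ground-truth token ids
result = input.clone()
result[:, 2:4] = -1                      # stand-in for remove_popped_values

ar_mask = (result != input).long()       # 1 where the model must predict
result *= 1 - ar_mask                    # blank those positions

print(ar_mask)  # tensor([[0, 0, 1, 1, 0]])
print(result)   # tensor([[5, 7, 0, 0, 8]])
# the autoregression then rewrites only the positions where ar_mask is 1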
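######################################################################

Note on detensorize / tensor2str in tasks.py: PicoCLVR and Grid now map token ids that are not in the vocabulary to "?" instead of raising a KeyError, for the same reason maze.save_image gained a grey "error" color: a model can emit out-of-range tokens. A stand-alone version of the fallback, with an illustrative vocabulary:

import torch

id2token = {0: "<nul>", 1: "red", 2: "green", 3: "blue"}  # made-up vocabulary

def detensorize(x):
    # x: LongTensor of shape (batch, length); unknown ids become "?"
    def id2tok(t):
        try:
            return id2token[t.item()]
        except KeyError:
            return "?"

    return [" ".join([id2tok(t) for t in r]) for r in x]

print(detensorize(torch.tensor([[1, 2, 3], [3, 7, 0]])))
# ['red green blue', 'blue ? <nul>']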