"--task",
type=str,
default="twotargets",
- help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+ help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
+##############################
+# filetask
+
+parser.add_argument("--filetask_file", type=str, default=None)
+
##############################
# rpl options
######################################################################
default_task_args = {
+ "file": {
+ "model": "37M",
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
"addition": {
"model": "352M",
"batch_size": 25,
######################################################################
-if args.task == "byheart":
+if args.task == "file":
+ assert (
+ args.filetask_file is not None
+ ), "You have to specify the task file with --filetask_file <filename>"
+ task = tasks.TaskFromFile(
+ args.filetask_file,
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ device=device,
+ )
+ args.max_percents_of_test_in_train = 0
+
+elif args.task == "byheart":
task = tasks.SandBox(
problem=problems.ProblemByHeart(),
nb_train_samples=args.nb_train_samples,
pass
+class TaskFromFile(Task):
+ def tensorize(self, pairs):
+ len_max = max([len(x[0]) for x in pairs])
+
+ input = torch.cat(
+ [
+ torch.tensor(
+ [
+ [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
+ for s in pairs
+ ]
+ )
+ ],
+ 0,
+ ).to("cpu")
+
+ pred_mask = torch.cat(
+ [
+ torch.tensor(
+ [
+ [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
+ for s in pairs
+ ]
+ )
+ ],
+ 0,
+ ).to("cpu")
+
+ return input, pred_mask
+
+ # trim all the tensors in the tuple z to remove as much token from
+ # left and right in the first tensor. If z is a tuple, all its
+ # elements are trimed according to the triming for the first
+ def trim(self, z, token="#"):
+ n = self.char2id[token]
+ if type(z) == tuple:
+ x = z[0]
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return tuple([t[:, a:b] for t in z])
+ else:
+ i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ return z[:, a:b]
+
+ def __init__(
+ self,
+ filename,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ device=torch.device("cpu"),
+ ):
+ self.batch_size = batch_size
+ self.device = device
+
+ pairs = []
+ with open(filename, "r") as f:
+ for _ in range(nb_train_samples + nb_test_samples):
+ sequence = f.readline().strip()
+ pred_mask = f.readline().strip()
+ assert len(sequence) == len(pred_mask)
+ assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}"
+ pairs.append((sequence, pred_mask))
+
+ symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+ print("SANITY", symbols)
+ self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
+ self.id2char = dict([(n, c) for c, n in self.char2id.items()])
+
+ print(self.char2id)
+
+ self.train_input, self.train_pred_masks = self.tensorize(
+ pairs[:nb_train_samples]
+ )
+ self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield self.trim(batch).to(self.device)
+
+ def vocabulary_size(self):
+ return len(self.char2id)
+
+ def tensor2str(self, t):
+ print(f"{type(t)=}")
+ return ["".join([self.id2char[x.item()] for x in s]) for s in t]
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ correct = self.trim(self.test_input[:1000]).to(self.device)
+ result = correct.clone()
+ pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
+ ar_mask = (pred_mask > 0).long()
+ result *= 1 - ar_mask # paraaaaanoiaaaaaaa
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_before {e}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ logger(f"----------------------------------------------------------")
+
+ for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
+ logger(f"test_after {e}")
+ logger(f"correct {c}")
+
+ logger(f"----------------------------------------------------------")
+
+ err_mask = (pred_mask == 2).long()
+ nb_total = err_mask.sum().item()
+ nb_correct = ((correct == result).long() * err_mask).sum().item()
+
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
####################
import problems
self.device = device
self.batch_size = batch_size
self.grid_factory = grid.GridFactory(size=size)
+ self.fraction_play = fraction_play
if logger is not None:
logger(
fraction_play=fraction_play,
progress_bar=lambda r: tqdm.tqdm(r),
)
+
self.test_descr = self.grid_factory.generate_samples(
nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
)
+ if fraction_play > 0:
+ self.play_descr = self.grid_factory.generate_samples(
+ nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
+ )
+ else:
+ self.play_descr = []
+
# Build the tokenizer
tokens = set()
- for d in [self.train_descr, self.test_descr]:
+ for d in [self.train_descr, self.test_descr, self.play_descr]:
for s in d:
for t in s.strip().split(" "):
tokens.add(t)
self.t_nul = self.token2id["#"]
self.t_true = self.token2id["true"]
self.t_false = self.token2id["false"]
+ self.t_pipe = self.token2id["|"]
# Tokenize the train and test sets
self.train_input = self.str2tensor(self.train_descr)
self.test_input = self.str2tensor(self.test_descr)
+ self.play_input = (
+ None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
+ )
def batches(self, split="train"):
assert split in {"train", "test"}
logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
- if n_epoch == 5 or n_epoch == 10 or n_epoch == 20:
- if save_attention_image is None:
- logger("no save_attention_image (is pycairo installed?)")
- else:
- for k in range(10):
- ns = k # torch.randint(self.test_input.size(0), (1,)).item()
- input = self.test_input[ns : ns + 1].clone()
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
- model.record_attention(True)
- model(BracketedSequence(input))
- model.train(t)
- ram = model.retrieve_attention()
- model.record_attention(False)
-
- tokens_output = [self.id2token[t.item()] for t in input[0]]
- tokens_input = ["n/a"] + tokens_output[:-1]
- for n_head in range(ram[0].size(1)):
- filename = os.path.join(
- result_dir,
- f"sandbox_attention_epoch_{n_epoch}_sample_{k}_head_{n_head}.pdf",
- )
- attention_matrices = [m[0, n_head] for m in ram]
- save_attention_image(
- filename,
- tokens_input,
- tokens_output,
- attention_matrices,
- k_top=10,
- # min_total_attention=0.9,
- token_gap=12,
- layer_gap=50,
- )
- logger(f"wrote {filename}")
+ if self.play_input is not None:
+ result = self.play_input.clone()
+ ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
+ result *= 1 - ar_mask # paraaaaanoiaaaaaaa
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"play_before {e}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"play_after {e}")
+
+ logger(f"----------------------------------------------------------")
######################################################################