From c9dbc3abf436df8af1379d04ab51159e821496f1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 19 Jul 2023 13:54:59 +0200 Subject: [PATCH] Update. --- main.py | 16 ++++++- rpl.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ tasks.py | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 258 insertions(+), 1 deletion(-) create mode 100755 rpl.py diff --git a/main.py b/main.py index efcc0dd..63e6668 100755 --- a/main.py +++ b/main.py @@ -36,7 +36,7 @@ parser.add_argument( "--task", type=str, default="sandbox", - help="sandbox, picoclvr, mnist, maze, snake, stack, expr, world", + help="sandbox, picoclvr, mnist, maze, snake, stack, expr, rpl, world", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -206,6 +206,12 @@ default_task_args = { "nb_train_samples": 1000000, "nb_test_samples": 10000, }, + "rpl": { + "nb_epochs": 40, + "batch_size": 25, + "nb_train_samples": 1000000, + "nb_test_samples": 10000, + }, "world": { "nb_epochs": 10, "batch_size": 25, @@ -419,6 +425,14 @@ elif args.task == "expr": device=device, ) +elif args.task == "rpl": + task = tasks.RPL( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + device=device, + ) + elif args.task == "world": task = tasks.World( nb_train_samples=args.nb_train_samples, diff --git a/rpl.py b/rpl.py new file mode 100755 index 0000000..42db38c --- /dev/null +++ b/rpl.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +def rpl_exec(program, stack): + for op in program: + if op == "add": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a + b) + elif op == "min": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(min(a, b)) + elif op == "max": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(max(a, b)) + elif op == "swp": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack.append(a) + stack.append(b) + elif op == "rep": + if len(stack) > 1: + a, b = stack.pop(), stack.pop() + stack += [b] * a + elif op == "dup": + if len(stack) > 0: + a = stack.pop() + stack.append(a) + stack.append(a) + elif op == "del": + if len(stack) > 0: + a = stack.pop() + else: + raise ValueError(f"Unknown instruction {op}") + + +rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"] + +###################################################################### + + +def generate(nb_values=3, max_input=9, prog_len=6, nb_runs=5): + prog_len = 1 + torch.randint(prog_len - 1, (1,)).item() + prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))] + + result = [] + for _ in range(nb_runs): + stack = [x.item() for x in torch.randint(max_input + 1, (nb_values,))] + result = result + [""] + stack + rpl_exec(prog, stack) + result = result + [""] + stack + + result = result + [""] + prog + result = result + [""] + return result + + +def next_marker(seq, tokens, start=0): + pos = None + for t in tokens: + try: + i = seq.index(t, start) + if pos is None or i < pos: + pos = i + except ValueError: + pass + return pos + + +def check(seq): + io = [] + k = 0 + while seq[k] == "": + o = next_marker(seq, [""], start=k + 1) + e = next_marker(seq, ["", ""], start=o) + if o is None or e is None: + raise ValueError("Invalid input/output") + io.append((seq[k + 1 : o], seq[o + 1 : e])) + k = e + + if seq[k] == "": + e = next_marker(seq, [""], start=k) + if e is None: + prog = [] + else: + prog = seq[k + 1 : e] + + nb_total, nb_errors = 0, 0 + + if len(set(prog) - set(rpl_ops)) > 0: + for stack, target_stack in io: + nb_total += len(target_stack) + nb_errors += len(target_stack) + + else: + for stack, target_stack in io: + # print(f"INIT {stack} PROG {prog}") + rpl_exec(prog, stack) + # print(f"CHECK {stack} REF {target_stack} NB_ERROR {abs(len(stack) - len(target_stack))+sum([0 if x == y else 1 for x, y in zip(stack, target_stack)])}") + nb_total += len(target_stack) + nb_errors += abs(len(stack) - len(target_stack)) + nb_errors += sum([0 if x == y else 1 for x, y in zip(stack, target_stack)]) + + return nb_total, nb_errors + + +###################################################################### + +if __name__ == "__main__": + seq = generate() + print(seq) + seq[3] = 7 + print(seq) + print(check(seq)) diff --git a/tasks.py b/tasks.py index c5418b4..a3d47f5 100755 --- a/tasks.py +++ b/tasks.py @@ -1021,6 +1021,124 @@ class Stack(Task): ############################################################## +###################################################################### + +import rpl + + +class RPL(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [ + self.token2id[str(c)] + for c in s + [""] * (len_max - len(s)) + ] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + self.device = device + + train_sequences = [ + rpl.generate() + for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data") + ] + test_sequences = [ + rpl.generate() for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data") + ] + + symbols = list( + set([""] + [x for l in train_sequences + test_sequences for x in l]) + ) + val_max = max([x if type(x) is int else 0 for x in symbols]) + symbols = list(filter(lambda x: type(x) is str, symbols)) + symbols.sort() + symbols += [str(n) for n in range(val_max + 1)] + print(f"{val_max=}") + self.token2id = dict([(c, n) for n, c in enumerate(symbols)]) + self.id2token = dict([(n, c) for c, n in self.token2id.items()]) + + self.t_nul, self.t_prog = self.token2id[""], self.token2id[""] + + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + last = (batch != self.t_nul).max(0).values.nonzero().max() + 3 + batch = batch[:, :last] + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + def compute_nb_errors(input, nb_to_log=0): + result = input.clone() + s = (result == self.t_prog).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) + result = (1 - ar_mask) * result + ar_mask * self.t_nul + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + if nb_to_log > 0: + for x in result[:nb_to_log]: + s = " ".join([self.id2token[i.item()] for i in x]) + logger(f"check {n_epoch} {s}") + nb_to_log -= min(nb_to_log, result.size(0)) + + sum_nb_total, sum_nb_errors = 0, 0 + for x in result: + seq = [self.id2token[i.item()] for i in x] + nb_total, nb_errors = rpl.check(seq) + sum_nb_total += nb_total + sum_nb_errors += nb_errors + + return sum_nb_total, sum_nb_errors + + test_nb_total, test_nb_errors = compute_nb_errors(self.test_input, nb_to_log=10) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%" + ) + + ###################################################################### -- 2.39.5