From 233f57347c9560aec2f3cbaf001a8efa56a0243b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 8 Jul 2023 12:15:56 +0200 Subject: [PATCH] Update. --- expr.py | 4 +--- tasks.py | 7 ++++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/expr.py b/expr.py index 818360b..e539fcb 100755 --- a/expr.py +++ b/expr.py @@ -24,7 +24,7 @@ def random_expr(variables, budget): else: return str(torch.randint(10, (1,)).item()) else: - op = torch.randint(4, (1,)).item() + op = torch.randint(3, (1,)).item() if op == 0: e = random_expr(variables, budget - 2) if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"): @@ -38,8 +38,6 @@ def random_expr(variables, budget): if op == 1: return e1 + "+" + e2 elif op == 2: - return e1 + "+" + e2 - elif op == 3: return e1 + "*" + e2 diff --git a/tasks.py b/tasks.py index b277b96..463d94c 100755 --- a/tasks.py +++ b/tasks.py @@ -937,11 +937,12 @@ class Expr(Task): input = self.tensorize(sequences) result = input.clone() - ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) result = (1 - ar_mask) * result + ar_mask * self.filler - # for n in range(result.size(0)): - # logger(f"test_before {self.seq2str(result[n])}") + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") masked_inplace_autoregression( model, -- 2.39.5