From 49738bb51b386e62f86f861237cbe32b7a2ad479 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 4 Jul 2023 18:08:55 +0200 Subject: [PATCH] Update. --- main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index beafc19..b907e60 100755 --- a/main.py +++ b/main.py @@ -1091,7 +1091,7 @@ class TaskExpr(Task): result = input.clone() filler, space = self.char2id["#"], self.char2id[" "] ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1) - result = (1 - ar_mask) * result + filler * ar_mask + result = (1 - ar_mask) * result + ar_mask * filler masked_inplace_autoregression( model, self.batch_size, result, ar_mask, device=self.device ) @@ -1113,16 +1113,19 @@ class TaskExpr(Task): result = input.clone() filler, space = self.char2id["#"], self.char2id[" "] ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1) - result = (1 - ar_mask) * result + filler * ar_mask + result = (1 - ar_mask) * result + ar_mask * filler for n in range(result.size(0)): s = "".join([self.id2char[k.item()] for k in result[n]]) log_string(f"test_before {s}") masked_inplace_autoregression( model, self.batch_size, result, ar_mask, device=self.device ) + correct = (1 - ar_mask) * space + ar_mask * input for n in range(result.size(0)): s = "".join([self.id2char[k.item()] for k in result[n]]) log_string(f"test_after {s}") + s = "".join([self.id2char[k.item()] for k in correct[n]]) + log_string(f"correct {s}") ############################################################## model.train(t) -- 2.20.1