From 365adcf76a8505350af9d605e5134135a53f6f74 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 07:32:07 +0200 Subject: [PATCH] Update. --- main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 2f867db..00a6cd1 100755 --- a/main.py +++ b/main.py @@ -981,10 +981,9 @@ def test_ae(local_device=main_device): logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits) pred_result = result.clone() - result[not_converged] = ( - (1 - mask_generate) * input + mask_generate * dist.sample() - )[not_converged] - not_converged = (pred_result == result).long().min(dim=1).values == 0 + update = (1 - mask_generate) * input + mask_generate * dist.sample() + result[not_converged] = update[not_converged] + not_converged = (pred_result != result).max(dim=1).values nb_it += 1 print("DEBUG", nb_it, i.long().sum().item()) if not i.any() or nb_it > 100: -- 2.39.5