From f08778775c6137993f45396408b1a50bf023e5be Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Sat, 20 Aug 2022 07:47:14 +0200 Subject: [PATCH] Replaced --synthesis_sampling with --deterministic_synthesis. --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index f6934b7..ee44ebe 100755 --- a/main.py +++ b/main.py @@ -65,8 +65,8 @@ parser.add_argument('--nb_blocks', parser.add_argument('--dropout', type = float, default = 0.1) -parser.add_argument('--synthesis_sampling', - action='store_true', default = True) +parser.add_argument('--deterministic_synthesis', + action='store_true', default = False) parser.add_argument('--no_checkpoint', action='store_true', default = False) @@ -132,11 +132,11 @@ def autoregression( for s in range(first, input.size(1)): output = model(input) logits = output[:, s] - if args.synthesis_sampling: + if args.deterministic_synthesis: + t_next = logits.argmax(1) + else: dist = torch.distributions.categorical.Categorical(logits = logits) t_next = dist.sample() - else: - t_next = logits.argmax(1) input[:, s] = t_next return results -- 2.39.5