From: François Fleuret Date: Wed, 18 Sep 2024 14:31:50 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=62f604a75378d2bbf1546b60fcf8eea92be928f8;p=culture.git Update. --- diff --git a/main.py b/main.py index c6eedfb..464b217 100755 --- a/main.py +++ b/main.py @@ -480,41 +480,42 @@ def prioritized_rand(low): return y -def ae_generate(model, nb, local_device=main_device, desc="generate"): +def ae_generate(model, nb, local_device=main_device): model.eval().to(local_device) all_input = quiz_machine.pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) + all_changed = torch.full((all_input.size(0),), True, device=all_input.device) - src = zip( - all_input.split(args.physical_batch_size), - all_masks.split(args.physical_batch_size), - ) + for it in range(args.diffusion_nb_iterations): + if not all_changed.any(): + break - if desc is not None: - src = tqdm.tqdm( - src, - dynamic_ncols=True, - desc="generate", - total=all_input.size(0) // args.physical_batch_size, + sub_input = all_input[all_changed].clone() + sub_masks = all_masks[all_changed].clone() + sub_changed = all_changed[all_changed].clone() + + src = zip( + sub_input.split(args.physical_batch_size), + sub_masks.split(args.physical_batch_size), + sub_changed.split(args.physical_batch_size), ) - for input, masks in src: - changed = True - for it in range(args.diffusion_nb_iterations): + for input, masks, changed in src: with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): logits = model(input * 2 + masks) dist = torch.distributions.categorical.Categorical(logits=logits) output = dist.sample() - r = prioritized_rand(input != output) mask_changes = (r <= args.diffusion_proba_corruption).long() * masks update = (1 - mask_changes) * input + mask_changes * output - if update.equal(input): - break - else: - changed = changed & (update != input).max(dim=1).values - input[changed] = update[changed] + changed[...] = changed & (update != input).max(dim=1).values + input[...] = update + + a = all_changed.clone() + all_input[a] = sub_input + all_masks[a] = sub_masks + all_changed[a] = sub_changed return all_input @@ -709,10 +710,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): generator_id = model.id c_quizzes = ae_generate( - model=model, - nb=args.physical_batch_size, - local_device=local_device, - desc=None, + model=model, nb=args.physical_batch_size * 10, local_device=local_device ) # Select the ones that are solved properly by some models and @@ -847,11 +845,7 @@ if args.quizzes is not None: mask_generate = quiz_machine.make_quiz_mask( quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - result = ae_generate( - model, - (1 - mask_generate) * quizzes, - mask_generate, - ) + result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate) record.append(result) result = torch.cat(record, dim=0)