return y
-def ae_generate(model, nb, local_device=main_device, desc="generate"):
+def ae_generate(model, nb, local_device=main_device):
model.eval().to(local_device)
all_input = quiz_machine.pure_noise(nb, local_device)
all_masks = all_input.new_full(all_input.size(), 1)
+ all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
- src = zip(
- all_input.split(args.physical_batch_size),
- all_masks.split(args.physical_batch_size),
- )
+ for it in range(args.diffusion_nb_iterations):
+ if not all_changed.any():
+ break
- if desc is not None:
- src = tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="generate",
- total=all_input.size(0) // args.physical_batch_size,
+ sub_input = all_input[all_changed].clone()
+ sub_masks = all_masks[all_changed].clone()
+ sub_changed = all_changed[all_changed].clone()
+
+ src = zip(
+ sub_input.split(args.physical_batch_size),
+ sub_masks.split(args.physical_batch_size),
+ sub_changed.split(args.physical_batch_size),
)
- for input, masks in src:
- changed = True
- for it in range(args.diffusion_nb_iterations):
+ for input, masks, changed in src:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(input * 2 + masks)
dist = torch.distributions.categorical.Categorical(logits=logits)
output = dist.sample()
-
r = prioritized_rand(input != output)
mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
update = (1 - mask_changes) * input + mask_changes * output
- if update.equal(input):
- break
- else:
- changed = changed & (update != input).max(dim=1).values
- input[changed] = update[changed]
+ changed[...] = changed & (update != input).max(dim=1).values
+ input[...] = update
+
+ a = all_changed.clone()
+ all_input[a] = sub_input
+ all_masks[a] = sub_masks
+ all_changed[a] = sub_changed
return all_input
generator_id = model.id
c_quizzes = ae_generate(
- model=model,
- nb=args.physical_batch_size,
- local_device=local_device,
- desc=None,
+ model=model, nb=args.physical_batch_size * 10, local_device=local_device
)
# Select the ones that are solved properly by some models and
mask_generate = quiz_machine.make_quiz_mask(
quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- result = ae_generate(
- model,
- (1 - mask_generate) * quizzes,
- mask_generate,
- )
+ result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate)
record.append(result)
result = torch.cat(record, dim=0)