return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-def predict(model, imt_set, local_device=main_device):
+def predict(model, imt_set, local_device=main_device, desc="predict"):
model.eval().to(local_device)
record = []
- for imt in tqdm.tqdm(
- imt_set.split(args.physical_batch_size),
- dynamic_ncols=True,
- desc="predict",
- total=imt_set.size(0) // args.physical_batch_size,
- ):
+ src = imt_set.split(args.physical_batch_size)
+
+ if desc is not None:
+ src = tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc=desc,
+ total=imt_set.size(0) // args.physical_batch_size,
+ )
+
+ for imt in src:
# some paranoia
imt = imt.clone()
imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
[input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1
)
- result = predict(model, imt_set, local_device=local_device)
+ result = predict(model, imt_set, local_device=local_device, desc=None)
result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
return result
return y
-def generate(model, nb, local_device=main_device):
+def generate(model, nb, local_device=main_device, desc="generate"):
model.eval().to(local_device)
all_input = quiz_machine.pure_noise(nb, local_device)
all_masks = all_input.new_full(all_input.size(), 1)
- for input, masks in tqdm.tqdm(
- zip(
- all_input.split(args.physical_batch_size),
- all_masks.split(args.physical_batch_size),
- ),
- dynamic_ncols=True,
- desc="generate",
- total=all_input.size(0) // args.physical_batch_size,
- ):
+ src = zip(
+ all_input.split(args.physical_batch_size),
+ all_masks.split(args.physical_batch_size),
+ )
+
+ if desc is not None:
+ src = tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+            desc=desc,
+ total=all_input.size(0) // args.physical_batch_size,
+ )
+
+ for input, masks in src:
changed = True
for it in range(args.diffusion_nb_iterations):
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
generator_id = model.id
c_quizzes = generate(
- moel=copy_for_inference(model),
+ model=model,
nb=args.physical_batch_size,
local_device=local_device,
+ desc=None,
)
nb_correct, nb_wrong = 0, 0
+
for i, model in enumerate(models):
model = copy.deepcopy(model).to(local_device).eval()
result = predict_full(model, c_quizzes, local_device=local_device)
nb_validated += to_keep.long().sum()
record.append(c_quizzes[to_keep])
+ log_string(f"generate_c_quizzes {nb_validated}")
+
return torch.cat(record)