desc="predict",
total=imt_set.size(0) // args.physical_batch_size,
):
- masks = imt[:, 1]
- imt = imt * (1 - masks[:, None]) # paranoia
+ # paranoia: erase the input tokens at the masked positions, so the
+ # model cannot see the values it is supposed to predict
+ imt = imt.clone()
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
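# the token and its mask bit are packed into a single input id: token * 2 + mask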
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(imt[:, 0] * 2 + imt[:, 1])
dist = torch.distributions.categorical.Categorical(logits=logits)
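# iterative refinement: a quiz stops being updated once a full pass leaves it unchanged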
changed = True
for it in range(args.diffusion_nb_iterations):
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(input)
+ logits = model(input * 2 + masks)
dist = torch.distributions.categorical.Categorical(logits=logits)
output = dist.sample()
# build the update by resampling only the masked positions, keeping the given tokens
update = (1 - masks) * input + masks * output
changed = changed & (update != input).max(dim=1).values
input[changed] = update[changed]
- return input
+ return all_input
######################################################################
log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
-def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+######################################################################
+
+
+def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
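+ """Train the model for one epoch, then predict on the quizzes and save the result as an image."""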
# train
one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
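# test: predict the masked parts of the quizzes and check which come out exactly right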
imt_set = IMT_batch_prediction(quizzes.to(local_device))
result = predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
+
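# a quiz counts as correct only if every token matches; correct_parts carries
# +1/-1 on the masked parts (0 elsewhere), predicted_parts flags which parts were masked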
correct = (quizzes == result).min(dim=1).values.long()
correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[:, :, 1]
predicted_parts = correct_parts.abs()
+
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
f"culture_prediction_{n_epoch}_{model.id}.png",
# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
multithread_execution(
- one_train_test_epoch,
+ one_complete_epoch,
[(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
)