######################################################################
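+ # Returns None when nb_hints == 0, otherwise a long mask of the same
+ # shape as mask_generate with nb_hints positions set per row, picked
+ # uniformly at random among the positions set in mask_generate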
+ def make_mask_hints(self, mask_generate, nb_hints):
+ if nb_hints == 0:
+ mask_hints = None
+ else:
+ u = (
+ torch.rand(mask_generate.size(), device=mask_generate.device)
+ * mask_generate
+ )
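+ # Keep the nb_hints largest values of u in each row, i.e. the
+ # positions where u exceeds the (nb_hints+1)-th largest value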
+ mask_hints = (
+ u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
+ ).long()
+
+ return mask_hints
+
# This function gets a clean target x_0 and a mask indicating which
# part to generate (conditionally on the others), and returns the
# logits computed from an x_t ~ X_t|X_0=x_0 with t picked at random
def logits_hat_x_0_from_random_iteration(
- self, model, x_0, mask_generate, prompt_noise=0.0
+ self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0
):
+ noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+
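+ # Flag the samples for which less than half of the tokens have to
+ # be generated; they are processed in a single pass from pure noise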
+ single_iteration = (
+ mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+ ).long()[:, None]
+
+ mask_hints = self.make_mask_hints(mask_generate, nb_hints)
+
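+ # The positions given as hints are kept clean, so they are removed
+ # from the part to generate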
+ if mask_hints is None:
+ mask_start = mask_generate
+ else:
+ mask_start = mask_generate * (1 - mask_hints)
+
# We favor iterations near the clean signal
probs_iterations = 0.1 ** torch.linspace(
t = dist.sample() + 1
- x_t = self.sample_x_t_given_x_0(x_0, t)
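+ # Pure noise for the single-iteration samples, x_0 degraded up to
+ # step t for the others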
+ x_t = single_iteration * noise + (
+ 1 - single_iteration
+ ) * self.sample_x_t_given_x_0(x_0, t)
# Only the part to generate is degraded, the rest is a perfect
# noise-free conditioning
######################################################################
- def ae_generate(self, model, x_0, mask_generate, mask_hints=None):
+ def generate(self, model, x_0, mask_generate, nb_hints=0):
noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
single_iteration = (
mask_generate.sum(dim=1) < mask_generate.size(1) // 2
).long()[:, None]
+ mask_hints = self.make_mask_hints(mask_generate, nb_hints)
+
if mask_hints is None:
mask_start = mask_generate
else:
parser.add_argument("--reboot", action="store_true", default=False)
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
# ----------------------------------
parser.add_argument("--model", type=str, default="37M")
######################################################################
-def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
+def model_proba_solutions(model, input, log_probas=False, reduce=True):
record = []
for x_0 in input.split(args.batch_size):
######################################################################
-def ae_batches(
+def batches(
quiz_machine,
nb,
data_structures,
######################################################################
-def run_ae_test(
+def run_test(
model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
):
if prefix is None:
nb_test_samples, acc_test_loss = 0, 0.0
- for x_0, mask_generate in ae_batches(
+ for x_0, mask_generate in batches(
quiz_machine,
args.nb_test_samples,
data_structures,
nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
- for x_0, mask_generate in ae_batches(
+ for x_0, mask_generate in batches(
quiz_machine,
args.nb_test_samples,
data_structures,
c_quizzes=c_quizzes,
desc="test",
):
- result = diffuser.ae_generate(
- model, (1 - mask_generate) * x_0, mask_generate
- )
+ result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate)
correct = (result == x_0).min(dim=1).values.long()
predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
:, :, 1
######################################################################
-def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
+def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
model.train().to(local_device)
optimizer_to(model.optimizer, local_device)
# scaler = torch.amp.GradScaler("cuda")
- for x_0, mask_generate in ae_batches(
+ for x_0, mask_generate in batches(
quiz_machine,
args.nb_train_samples,
data_structures,
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
)
- model.test_accuracy = run_ae_test(
+ model.test_accuracy = run_test(
model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
)
if args.nb_test_alien_samples > 0:
- run_ae_test(
+ run_test(
model,
alien_quiz_machine,
n_epoch,
models,
c_quizzes,
local_device,
- nb_have_to_be_correct=3,
- nb_have_to_be_wrong=1,
- nb_mistakes_to_be_wrong=5,
+ nb_have_to_be_correct,
+ nb_have_to_be_wrong,
+ nb_mistakes_to_be_wrong,
nb_hints=0,
nb_runs=1,
):
+ ######################################################################
+ # If there are too many quizzes, process them batch by batch
+
if c_quizzes.size(0) > args.inference_batch_size:
record = []
for q in c_quizzes.split(args.inference_batch_size):
)
)
- return (torch.cat([tk for tk, _ in record], dim=0)), (
- torch.cat([w for _, w in record], dim=0)
- )
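+ # Concatenate the per-batch outputs component-wise, so that any
+ # number of returned tensors is supported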
+ r = []
+ for k in range(len(record[0])):
+ r.append(torch.cat([x[k] for x in record], dim=0))
+
+ return tuple(r)
+ ######################################################################
record_wrong = []
nb_correct, nb_wrong = 0, 0
sub_correct, sub_wrong = False, True
for _ in range(nb_runs):
- if nb_hints == 0:
- mask_hints = None
- else:
- u = (
- torch.rand(mask_generate.size(), device=mask_generate.device)
- * mask_generate
- )
- mask_hints = (
- u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
- ).long()
-
- result = ae_generate(
+ result = diffuser.generate(
model=model,
x_0=c_quizzes,
mask_generate=mask_generate,
- mask_hints=mask_hints,
+ nb_hints=nb_hints,
)
nb_mistakes = (result != c_quizzes).long().sum(dim=1)
######################################################################
-def generate_ae_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes(models, nb, local_device=main_device):
# To be thread-safe we must make copies
def copy_for_inference(model):
quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
)
- c_quizzes = ae_generate(model, template, mask_generate)
+ c_quizzes = diffuser.generate(model, template, mask_generate)
to_keep = quiz_machine.problem.trivial(c_quizzes) == False
c_quizzes = c_quizzes[to_keep]
models,
c_quizzes,
local_device,
+ nb_have_to_be_correct=args.nb_have_to_be_correct,
+ nb_have_to_be_wrong=args.nb_have_to_be_wrong,
+ nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
nb_hints=args.nb_hints,
nb_runs=args.nb_runs,
)
c_quizzes = c_quizzes.to(main_device)
with torch.autograd.no_grad():
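+ # Count for each quiz how many models solve it and how many fail,
+ # without requiring that any model fails (nb_have_to_be_wrong=0)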
+ to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
+ models,
+ c_quizzes,
+ main_device,
+ nb_have_to_be_correct=args.nb_have_to_be_correct,
+ nb_have_to_be_wrong=0,
+ nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
+ nb_hints=0,
+ )
+
if solvable_only:
- to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
- models,
- c_quizzes,
- main_device,
- nb_have_to_be_correct=2,
- nb_have_to_be_wrong=0,
- nb_hints=0,
- )
c_quizzes = c_quizzes[to_keep]
+ nb_correct = nb_correct[to_keep]
+ nb_wrong = nb_wrong[to_keep]
- comments = []
+ comments = []
- for c, w in zip(nb_correct, nb_wrong):
- comments.append("nb_correct {c} nb_wrong {w}")
+ for c, w in zip(nb_correct, nb_wrong):
+ comments.append(f"nb_correct {c} nb_wrong {w}")
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
mask_generate = quiz_machine.make_quiz_mask(
quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- result = ae_generate(
+ result = diffuser.generate(
model,
(1 - mask_generate) * quizzes,
mask_generate,
records.append(fun(*args))
for args in arguments:
- t = threading.Thread(target=threadable_fun, daemon=True, args=args)
-
# To get a different sequence between threads
log_string(f"dummy_rand {torch.rand(1)}")
+ t = threading.Thread(target=threadable_fun, daemon=True, args=args)
threads.append(t)
t.start()
nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
c_quizzes, agreements = multithread_execution(
- generate_ae_c_quizzes,
+ generate_c_quizzes,
[(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
)
solvable_only=True,
)
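+ # Sort by decreasing number of tokens that differ between the two
+ # last quadrants (B and f_B), ignoring each quadrant's first token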
- u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, 1:]
+ u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:]
i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
save_c_quizzes_with_scores(
# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
multithread_execution(
- one_ae_epoch,
+ one_epoch,
[
(model, quiz_machine, n_epoch, c_quizzes, gpu)
for model, gpu in zip(weakest_models, gpus)