######################################################################
def make_mask_hints(self, mask_generate, nb_hints):
- if nb_hints == 0:
- mask_hints = None
+ if nb_hints is None:
+ # no hints: an all-zero mask keeps the arithmetic of the callers valid
+ mask_hints = torch.zeros_like(mask_generate)
else:
u = (
torch.rand(mask_generate.size(), device=mask_generate.device)
* mask_generate
)
- mask_hints = (
- u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
- ).long()
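+ # nb_hints is now a per-sample tensor: for each row, the nb_hints positions of
+ # mask_generate with the largest random scores u become hints, e.g. for a row
+ # u = [0.9, 0.1, 0.7, 0.0] and nb_hints = 2, v = 0.1 and mask_hints = [1, 0, 1, 0]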
+ v = u.sort(dim=1, descending=True).values.gather(
+ dim=1, index=nb_hints[:, None]
+ )
+ mask_hints = (u > v).long()
return mask_hints
# logits of hat_x_0, starting from an x_t | X_0 = x_0 picked at random, with t random
def logits_hat_x_0_from_random_iteration(
- self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0
+ self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0
):
noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
single_iteration = (
mask_generate.sum(dim=1) < mask_generate.size(1) // 2
).long()[:, None]
- mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-
- if mask_hints is None:
- mask_start = mask_generate
- else:
- mask_start = mask_generate * (1 - mask_hints)
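+ # hints are only kept for the samples that are processed in a single iteration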
+ mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration
# We favor iterations near the clean signal
t = dist.sample() + 1
- x_t = single_iteration * noise + (
- 1 - single_iteration
- ) * self.sample_x_t_given_x_0(x_0, t)
-
- # Only the part to generate is degraded, the rest is a perfect
- # noise-free conditionning
-
+ # hinted cells start from the clean x_0, the other cells to generate from pure noise
+ x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
+ x_t = self.sample_x_t_given_x_0(x_0, t)
+ x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
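+ # only the part to generate is degraded, the rest is a perfect noise-free conditioning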
x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
# We may inject noise to prevent high-complexity non-structure
######################################################################
- def generate(self, model, x_0, mask_generate, nb_hints=0):
+ def generate(self, model, x_0, mask_generate, nb_hints=None):
noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
single_iteration = (
mask_generate.sum(dim=1) < mask_generate.size(1) // 2
).long()[:, None]
mask_hints = self.make_mask_hints(mask_generate, nb_hints)
- if mask_hints is None:
- mask_start = mask_generate
- else:
- mask_start = mask_generate * (1 - mask_hints)
-
- x_t = (1 - mask_start) * x_0 + mask_start * noise
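+ # start from the clean x_0 at the hinted cells and from noise elsewhere when a
+ # single iteration suffices, otherwise from an x_t sampled given x_0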
+ x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
+ x_t = self.sample_x_t_given_x_0(x_0, t)
+ x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
+ x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
changed = True
x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(x_t_with_mask)
- # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
dist = torch.distributions.categorical.Categorical(logits=logits)
hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
logits = logits_hat_x_0_from_random_iteration(
- model, x_0, mask_generate, prompt_noise=args.prompt_noise
+ model=model,
+ x_0=x_0,
+ mask_generate=mask_generate,
+ prompt_noise=args.prompt_noise,
)
loss_per_token = F.cross_entropy(
logits.transpose(1, 2), x_0, reduction="none"
# Save some images
- if n_epoch < 100:
- for f, record in [("prediction", record_d), ("generation", record_nd)]:
- result, predicted_parts, correct_parts = bag_to_tensors(record)
+ for f, record in [("prediction", record_d), ("generation", record_nd)]:
+ result, predicted_parts, correct_parts = bag_to_tensors(record)
- filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+ filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result[:128],
- predicted_parts=predicted_parts[:128],
- correct_parts=correct_parts[:128],
- )
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],
+ )
- log_string(f"wrote {filename}")
+ log_string(f"wrote {filename}")
return nb_correct / nb_total
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
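+ # randomly give args.nb_hints hints to roughly half of the training samples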
+ nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = diffuser.logits_hat_x_0_from_random_iteration(
model=model,
x_0=x_0,
mask_generate=mask_generate,
prompt_noise=args.prompt_noise,
+ nb_hints=nb_hints,
)
loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
nb_have_to_be_correct,
nb_have_to_be_wrong,
nb_mistakes_to_be_wrong,
- nb_hints=0,
+ nb_hints,
nb_runs=1,
):
######################################################################
if c_quizzes.size(0) > args.inference_batch_size:
record = []
- for q in c_quizzes.split(args.inference_batch_size):
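+ # split nb_hints along with the quizzes so every sub-batch keeps its own hint counts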
+ for q, nh in zip(
+ c_quizzes.split(args.inference_batch_size),
+ nb_hints.split(args.inference_batch_size),
+ ):
record.append(
quiz_validation(
models=models,
nb_have_to_be_correct=nb_have_to_be_correct,
nb_have_to_be_wrong=nb_have_to_be_wrong,
nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
- nb_hints=nb_hints,
+ nb_hints=nh,
nb_runs=nb_runs,
)
)
nb_correct += correct.long()
nb_wrong += wrong.long()
- # log_string(f"{nb_hints=} {nb_correct=}")
- # log_string(f"{nb_hints=} {nb_wrong=}")
-
to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
wrong = torch.cat(record_wrong, dim=1)
to_keep = quiz_machine.problem.trivial(c_quizzes) == False
c_quizzes = c_quizzes[to_keep]
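+ # request args.nb_hints hints for every surviving candidate quiz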
+ nb_hints = torch.full(
+ (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device
+ )
+
if c_quizzes.size(0) > 0:
to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
models,
nb_have_to_be_correct=args.nb_have_to_be_correct,
nb_have_to_be_wrong=args.nb_have_to_be_wrong,
nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
- nb_hints=args.nb_hints,
+ nb_hints=nb_hints,
nb_runs=args.nb_runs,
)
nb_have_to_be_correct=args.nb_have_to_be_correct,
nb_have_to_be_wrong=0,
nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
- nb_hints=0,
+ nb_hints=None,
)
if solvable_only: