######################################################################
-def model_proba_solutions(model, input, log_probas=False, reduce=True):
- record = []
-
- for x_0 in input.split(args.batch_size):
- loss = 0
-
- for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
- mask_generate = quiz_machine.make_quiz_mask(
- quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
- )
- logits = logits_hat_x_0_from_random_iteration(
- model=model,
- x_0=x_0,
- mask_generate=mask_generate,
- prompt_noise=args.prompt_noise,
- )
- loss_per_token = F.cross_entropy(
- logits.transpose(1, 2), x_0, reduction="none"
- )
- if reduce:
- loss += (loss_per_token * mask_generate).sum(dim=1)
- else:
- loss += loss_per_token * mask_generate
-
- record.append(loss)
-
- loss = torch.cat(record, dim=0)
-
- if log_probas:
- return -loss
- else:
- return (-loss).exp()
-
-
-######################################################################
-
-
-def batches(
- quiz_machine,
- nb,
- data_structures,
- local_device,
- c_quizzes=None,
- alien_quiz_machine=None,
- desc=None,
- batch_size=args.batch_size,
-):
- c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
-
- full_input, full_mask_generate, _ = quiz_machine.data_input(
- nb,
- c_quiz_bags,
- data_structures=data_structures,
- c_quiz_multiplier=args.c_quiz_multiplier,
- )
-
- src = zip(
- full_input.split(batch_size),
- full_mask_generate.split(batch_size),
- )
-
- if desc is not None:
- src = tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc=desc,
- total=full_input.size(0) // batch_size,
- )
-
- for input, mask_generate in src:
- yield (
- input.to(local_device),
- mask_generate.to(local_device),
- )
-
-
-def NTC_channel_cat(*x):
- return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def NTC_masked_cross_entropy(output, targets, mask):
- loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
- return (loss_per_token * mask).mean()
-
-
def masked_cross_entropy(output, targets, masks):
loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
######################################################################
-
-def run_test(
- model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
-):
- if prefix is None:
- prefix = ""
- else:
- prefix = prefix + "_"
-
- with torch.autograd.no_grad():
- model.eval().to(local_device)
-
- # Compute the loss
-
- nb_test_samples, acc_test_loss = 0, 0.0
-
- for x_0, mask_generate in batches(
- quiz_machine,
- args.nb_test_samples,
- data_structures,
- local_device,
- c_quizzes=c_quizzes,
- desc="test",
- ):
- logits = diffuser.logits_hat_x_0_from_random_iteration(
- model=model,
- x_0=x_0,
- mask_generate=mask_generate,
- )
- loss = masked_cross_entropy(logits, x_0, mask_generate)
- acc_test_loss += loss.item() * x_0.size(0)
- nb_test_samples += x_0.size(0)
-
- log_string(
- f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
- )
-
- # Compute the accuracy and save some images
-
- nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
-
- for x_0, mask_generate in batches(
- quiz_machine,
- args.nb_test_samples,
- data_structures,
- local_device,
- c_quizzes=c_quizzes,
- desc="test",
- ):
- result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate)
- correct = (result == x_0).min(dim=1).values.long()
- predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
- :, :, 1
- ]
- d = predicted_parts.sum(dim=-1) == 1
- correct = (2 * correct - 1) * d.long()
- nb_correct += (correct == 1).long().sum()
- nb_total += (correct != 0).long().sum()
- correct_parts = predicted_parts * correct[:, None]
- record_d.append((result[d], predicted_parts[d], correct_parts[d]))
- nd = d == False
- record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd]))
-
- log_string(
- f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
- )
-
- # Save some images
-
- for f, record in [("prediction", record_d), ("generation", record_nd)]:
- result, predicted_parts, correct_parts = bag_to_tensors(record)
-
- filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result[:128],
- predicted_parts=predicted_parts[:128],
- correct_parts=correct_parts[:128],
- )
-
- log_string(f"wrote {filename}")
-
- return nb_correct / nb_total
-
-
-######################################################################
-
-
-def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
- model.train().to(local_device)
- optimizer_to(model.optimizer, local_device)
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- # scaler = torch.amp.GradScaler("cuda")
-
- for x_0, mask_generate in batches(
- quiz_machine,
- args.nb_train_samples,
- data_structures,
- local_device,
- c_quizzes=c_quizzes,
- desc="training",
- ):
- x_0 = x_0.to(local_device)
- mask_generate = mask_generate.to(local_device)
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.zero_grad()
-
- nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints
-
- with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = diffuser.logits_hat_x_0_from_random_iteration(
- model=model,
- x_0=x_0,
- mask_generate=mask_generate,
- prompt_noise=args.prompt_noise,
- nb_hints=nb_hints,
- )
-
- loss = masked_cross_entropy(logits, x_0, mask_generate)
- acc_train_loss += loss.item() * x_0.size(0)
- nb_train_samples += x_0.size(0)
-
- loss.backward()
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.step()
-
- # scaler.scale(loss).backward()
-
- # if nb_train_samples % args.batch_size == 0:
- # scaler.step(model.optimizer)
-
- # scaler.update()
-
- log_string(
- f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
- )
-
- model.test_accuracy = run_test(
- model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
- )
-
- if args.nb_test_alien_samples > 0:
- run_test(
- model,
- alien_quiz_machine,
- n_epoch,
- c_quizzes=None,
- local_device=local_device,
- prefix="alien",
- )
-
-
-######################################################################
+# IMT for input / masks / target
def IMT_batch_prediction(input, proba_hints=0.0):
######################################################################
-# IMT for input / masks / target
-
def IMT_batch_generation(input):
nb = input.size(0)
def generate(model, nb, local_device=main_device):
- input = quiz_machine.problem.pure_noise(nb, local_device)
- masks = input.new_full(input.size(), 1)
- masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
-
- changed = True
- for it in range(args.diffusion_nb_iterations):
- with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(input)
- dist = torch.distributions.categorical.Categorical(logits=logits)
- output = dist.sample()
-
- r = prioritized_rand(input != output)
- mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
- update = (1 - mask_changes) * input + mask_changes * output
- if update.equal(input):
- break
- else:
- changed = changed & (update != input).max(dim=1).values
- input[changed] = update[changed]
+ all_input = quiz_machine.problem.pure_noise(nb, local_device)
+ all_masks = all_input.new_full(all_input.size(), 1)
+ all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0
+
+ for input, masks in tqdm.tqdm(
+ zip(
+ all_input.split(args.physical_batch_size),
+ all_masks.split(args.physical_batch_size),
+ ),
+ dynamic_ncols=True,
+ desc="predict",
+ total=all_input.size(0) // args.physical_batch_size,
+ ):
+ changed = True
+ for it in range(args.diffusion_nb_iterations):
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ output = dist.sample()
+
+ r = prioritized_rand(input != output)
+ mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+ update = (1 - mask_changes) * input + mask_changes * output
+ if update.equal(input):
+ break
+ else:
+ changed = changed & (update != input).max(dim=1).values
+ input[changed] = update[changed]
return input
######################################################################
-def batch_interleave(a, b, perm):
- return torch.cat([a, b])[perm].reshape(-1, args.physical_batch_size, a.size(1))
-
-
def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
if train:
label = "train"
args.c_quiz_multiplier,
)
- input_p, input_g = quizzes.to(local_device).chunk(2)
+ q1, q2 = quizzes.to(local_device).chunk(2)
+
imt_set = torch.cat(
- [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)]
+ [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)]
)
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
# generate
- result = generate(model, 25, local_device=local_device).to("cpu")
+ result = generate(model, 150, local_device=local_device).to("cpu")
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
f"culture_generation_{n_epoch}_{model.id}.png",