######################################################################
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
-
-######################################################################
-
-
-def masked_cross_entropy(output, targets, masks):
-    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * masks).mean()
-
-
-######################################################################
# Make args.nb_hints holes in the mask and copy the corresponding cell
# values from the target to the input
    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
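# A minimal illustrative sketch of the operation described above (the actual
# add_hints body is elided in this hunk). It assumes input, masks and targets
# have been unpacked from the IMT record as in add_noise below:
#
#     scores = torch.rand(masks.size(), device=masks.device) * masks
#     hints = torch.zeros_like(masks)
#     hints.scatter_(1, scores.topk(args.nb_hints, dim=1).indices, 1)
#     hints = hints * masks
#     input = (1 - hints) * input + hints * targets
#     masks = masks * (1 - hints)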
+# Replace pixels of the available input (mask == 0) with noise, each with
+# probability args.prompt_noise
+
+
+def add_noise(imt_set):
+    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    noise = quiz_machine.pure_noise(input.size(0), input.device)
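+    # select visible cells (mask == 0) for corruption, each with
+    # probability args.prompt_noise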
+    change = (1 - masks) * (
+        torch.rand(input.size(), device=input.device) < args.prompt_noise
+    ).long()
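+    # keep the original token where change == 0, take the noise token elsewhere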
+    input = (1 - change) * input + change * noise
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
# IMT for input / masks / target
    return torch.cat(record)
-def predict_full(model, input, with_hints=False, local_device=main_device):
+def predict_full(model, input, with_perturbations=False, local_device=main_device):
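+    # each quiz is replicated four times; the four masked predictions are
+    # summed back into a single result below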
    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
    nb = input.size(0)
    masks = input.new_zeros(input.size())
    input = (1 - masks) * targets
    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-    if with_hints:
+    if with_perturbations:
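+        # perturb the input as for the training batches: hint cells and prompt noise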
        imt_set = add_hints(imt_set)
+        imt_set = add_noise(imt_set)
    result = ae_predict(model, imt_set, local_device=local_device, desc=None)
    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
    q_p, q_g = quizzes.to(local_device).chunk(2)
    b_p = batch_for_prediction_imt(q_p)
    i = torch.rand(b_p.size(0)) < 0.5
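+    # prompt noise is added to every prediction sample, hint cells only to
+    # the random half selected by i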
+    b_p = add_noise(b_p)
    b_p[i] = add_hints(b_p[i])
    b_g = batch_for_generation_imt(q_g)
    imt_set = torch.cat([b_p, b_g])
        desc=label,
        total=quizzes.size(0) // args.physical_batch_size,
    ):
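+        # one physical batch of IMT records per iteration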
+        input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
        if train and nb_samples % args.batch_size == 0:
            model.optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(imt[:, 0] * 2 + imt[:, 1])
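+            # the input id encodes both the token value and the mask bit:
+            # 2 * token + mask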
+            logits = model(input * 2 + masks)
-            loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1])
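+            # inlined masked cross-entropy: the per-token loss is multiplied by
+            # the mask, so unmasked positions contribute zero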
+            loss_per_token = F.cross_entropy(
+                logits.transpose(1, 2), targets, reduction="none"
+            )
+            loss = (loss_per_token * masks).mean()
        acc_loss += loss.item() * imt.size(0)
        nb_samples += imt.size(0)
    result = predict_full(
        model=model,
        input=quizzes,
-        with_hints=True,
+        with_perturbations=True,
        local_device=local_device,
    )
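    # number of cells for which the prediction differs from the quiz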
    nb_mistakes = (result != quizzes).long().sum(dim=1)