- return (loss_per_token * mask).mean()
+def masked_cross_entropy(output, targets, masks):
+ loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+ return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
+
+
######################################################################
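For reference, a minimal runnable sketch of what the new loss computes, with hypothetical shapes: output holds logits of shape (N, T, V), targets and masks are (N, T), and the per-token cross entropy is averaged over masked positions only (the expand_as in the patch just guards the case where masks broadcasts over loss_per_token).

import torch
import torch.nn.functional as F

N, T, V = 2, 5, 7
output = torch.randn(N, T, V)        # logits, V scores per token
targets = torch.randint(V, (N, T))   # ground-truth token ids
masks = torch.randint(2, (N, T))     # 1 = position contributes to the loss

# F.cross_entropy wants the class dimension second, hence the transpose
loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
loss = (loss_per_token * masks).sum() / masks.sum()  # mean over masked tokens only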
x_0=x_0,
mask_generate=mask_generate,
)
- loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+ loss = masked_cross_entropy(logits, x_0, mask_generate)
acc_test_loss += loss.item() * x_0.size(0)
nb_test_samples += x_0.size(0)
nb_hints=nb_hints,
)
- loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+ loss = masked_cross_entropy(logits, x_0, mask_generate)
acc_train_loss += loss.item() * x_0.size(0)
nb_train_samples += x_0.size(0)
######################################################################
-def batch_prediction(input, proba_hints=0.0):
+def IMT_batch_prediction(input, proba_hints=0.0):
nb = input.size(0)
- mask = input.new_zeros(input.size())
- u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4)
- mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
if proba_hints > 0:
- h = torch.rand(input.size(), device=input.device) * mask
+ h = torch.rand(input.size(), device=input.device) * masks
mask_hints = h.argsort(dim=1, descending=True).argsort(dim=1) < args.nb_hints # keep the nb_hints largest h as hints
v = torch.rand(nb, device=input.device)[:, None]
mask_hints = mask_hints * (v < proba_hints).long()
- mask = (1 - mask_hints) * mask
+ masks = (1 - mask_hints) * masks
# noise = quiz_machine.problem.pure_noise(nb, input.device)
targets = input
- input = (1 - mask) * targets # + mask * noise
+ input = (1 - masks) * targets # + masks * noise
- return input, targets, mask
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
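A minimal sketch of the IMT layout this returns, with hypothetical sizes: the three same-shaped tensors are stacked along a new channel dimension, so imt[:, 0] is the masked input, imt[:, 1] the binary mask, and imt[:, 2] the clean targets.

import torch

targets = torch.randint(10, (2, 8))  # clean sequences, shape (N, T)
masks = torch.zeros_like(targets)
masks[:, 4:] = 1                     # erase the second half
input = (1 - masks) * targets        # masked positions become token 0

imt = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
assert imt.shape == (2, 3, 8)                            # (N, channel, T)
assert ((1 - imt[:, 1]) * imt[:, 2] == imt[:, 0]).all()  # input is masked targets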
-def predict(model, input, targets, mask, local_device=main_device):
+def predict(model, imt_set, local_device=main_device):
model.eval().to(local_device)
- input_batches = input.reshape(-1, args.physical_batch_size, input.size(1))
- targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1))
- mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1))
-
record = []
- for input, targets, mask in tqdm.tqdm(
- zip(input_batches, targets_batches, mask_batches),
+ for imt in tqdm.tqdm(
+ imt_set.split(args.physical_batch_size),
dynamic_ncols=True,
desc="predict",
- total=input.size(0) // args.physical_batch_size,
+ total=imt_set.size(0) // args.physical_batch_size,
):
- # noise = quiz_machine.problem.pure_noise(input.size(0), input.device)
- input = (1 - mask) * input # + mask * noise
+ masks = imt[:, 1]
+ imt[:, 0] = imt[:, 0] * (1 - masks) # paranoia: blank the input at masked positions only
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(NTC_channel_cat(input, mask))
+ logits = model(imt[:, 0] * 2 + imt[:, 1])
dist = torch.distributions.categorical.Categorical(logits=logits)
- result = (1 - mask) * input + mask * dist.sample()
+ result = (1 - masks) * imt[:, 0] + masks * dist.sample()
record.append(result)
return torch.cat(record)
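The expression imt[:, 0] * 2 + imt[:, 1] folds the mask bit into the token id, giving every (token, mask) pair a distinct id; this is why the model at the end of this patch is built with vocabulary_size * 2. A small sketch with hypothetical values:

import torch

tokens = torch.tensor([3, 1, 4])     # ids in [0, vocabulary_size)
mask_bits = torch.tensor([0, 1, 0])  # 1 = position to be predicted

combined = tokens * 2 + mask_bits    # ids in [0, 2 * vocabulary_size)
assert (combined // 2 == tokens).all()     # token recoverable
assert (combined % 2 == mask_bits).all()   # mask bit recoverable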
######################################################################
+# IMT stands for input / masks / targets
+
-def batch_generation(input):
+def IMT_batch_generation(input):
nb = input.size(0)
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.diffusion_nb_iterations, device=input.device
targets = input
input = (1 - mask_erased) * input + mask_erased * noise
- mask = input.new_full(input.size(), 1)
- mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
+ masks = input.new_full(input.size(), 1)
+ masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
- return input, targets, mask
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
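The generation mask frees every position except the first token of each of the four quadrants, which stays fixed. A quick sketch of the layout, assuming a hypothetical sequence length of 12, i.e. three tokens per quadrant:

import torch

nb, T = 2, 12
masks = torch.ones(nb, T, dtype=torch.long)
masks.reshape(nb, 4, -1)[:, :, 0] = 0  # pin the first token of each quadrant
# masks[0] is now [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]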
def prioritized_rand(low):
def generate(model, nb, local_device=main_device):
input = quiz_machine.problem.pure_noise(nb, local_device)
- mask = input.new_full(input.size(), 1)
- mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
+ masks = input.new_full(input.size(), 1)
+ masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
changed = True
for it in range(args.diffusion_nb_iterations):
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(NTC_channel_cat(input, mask))
+ logits = model(input * 2 + masks) # same (token, mask) encoding as in training
dist = torch.distributions.categorical.Categorical(logits=logits)
output = dist.sample()
r = prioritized_rand(input != output)
- mask_changes = (r <= args.diffusion_proba_corruption).long() * mask
+ mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
update = (1 - mask_changes) * input + mask_changes * output
if update.equal(input):
break
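One update step of the generation loop in isolation, substituting a plain uniform draw for prioritized_rand, whose body is not shown in this patch: an expected fraction diffusion_proba_corruption of the maskable positions is overwritten with the model's samples, and the loop stops early once an update changes nothing.

import torch

diffusion_proba_corruption = 0.25    # hypothetical stand-in for args.*
input = torch.randint(10, (2, 12))
output = torch.randint(10, (2, 12))  # stands in for dist.sample()
masks = torch.ones_like(input)

r = torch.rand(input.size())         # simplified; the real code uses prioritized_rand
mask_changes = (r <= diffusion_proba_corruption).long() * masks
update = (1 - mask_changes) * input + mask_changes * output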
)
input_p, input_g = quizzes.to(local_device).chunk(2)
- input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5)
- input_g, targets_g, mask_g = batch_generation(input_g)
+ imt_set = torch.cat(
+ [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)]
+ )
- perm = torch.randperm(quizzes.size(0), device=local_device)
- input_batches = batch_interleave(input_p, input_g, perm)
- targets_batches = batch_interleave(targets_p, targets_g, perm)
- mask_batches = batch_interleave(mask_p, mask_g, perm)
+ imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
- for input, targets, mask in tqdm.tqdm(
- zip(input_batches, targets_batches, mask_batches),
+ for imt in tqdm.tqdm(
+ imt_set.split(args.physical_batch_size),
dynamic_ncols=True,
desc=label,
total=quizzes.size(0) // args.physical_batch_size,
model.optimizer.zero_grad()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(NTC_channel_cat(input, mask))
+ logits = model(imt[:, 0] * 2 + imt[:, 1])
- loss = NTC_masked_cross_entropy(logits, targets, mask)
- acc_loss += loss.item() * input.size(0)
- nb_samples += input.size(0)
+ loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1])
+ acc_loss += loss.item() * imt.size(0)
+ nb_samples += imt.size(0)
if train:
loss.backward()
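Because both task sets share the IMT layout, one shuffle and one split produce physical batches that mix prediction and generation samples; a sketch of that batching pattern with hypothetical sizes:

import torch

imt_p = torch.zeros(6, 3, 8, dtype=torch.long)  # prediction IMT set
imt_g = torch.ones(6, 3, 8, dtype=torch.long)   # generation IMT set

imt_set = torch.cat([imt_p, imt_g])
imt_set = imt_set[torch.randperm(imt_set.size(0))]  # interleave the two tasks
for imt in imt_set.split(4):                        # batches of at most 4
    assert imt.shape[1:] == (3, 8)                  # each batch mixes both tasks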
# predict
quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
- input, targets, mask = batch_prediction(quizzes.to(local_device))
- result = predict(model, input, targets, mask, local_device=local_device).to("cpu")
- mask = mask.to("cpu")
+ imt_set = IMT_batch_prediction(quizzes.to(local_device))
+ result = predict(model, imt_set, local_device=local_device).to("cpu")
+ masks = imt_set[:, 1].to("cpu")
correct = (quizzes == result).min(dim=1).values.long()
- correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
+ correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
:, :, 1
]
predicted_parts = correct_parts.abs()
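The final bookkeeping folds everything into one signed value per quadrant: 2 * correct - 1 is +1 or -1 per quiz, and column 1 of the reshaped mask flags which quadrants were predicted, so correct_parts is +1 (predicted, correct), -1 (predicted, wrong) or 0 (not predicted). A tiny sketch, assuming one quiz whose quadrant 2 was predicted:

import torch

masks = torch.zeros(1, 12, dtype=torch.long)
masks.reshape(1, 4, -1)[:, 2, 1:] = 1  # quadrant 2 was the one to predict
correct = torch.tensor([1])            # the quiz was fully reconstructed

correct_parts = (2 * correct - 1)[:, None] * masks.reshape(1, 4, -1)[:, :, 1]
# correct_parts -> tensor([[0, 0, 1, 0]])
predicted_parts = correct_parts.abs()  # which quadrants were predicted at all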
models = []
for i in range(args.nb_models):
- # model = MyAttentionAE(
- model = attae.MaskedAttentionAE(
- vocabulary_size=vocabulary_size,
+ model = attae.AttentionAE(
+ vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,