Update.

author    François Fleuret <francois@fleuret.org>
          Tue, 17 Sep 2024 07:30:35 +0000 (09:30 +0200)
committer François Fleuret <francois@fleuret.org>
          Tue, 17 Sep 2024 07:30:35 +0000 (09:30 +0200)

main.py

diff --git a/main.py b/main.py
index e921ccd..3339838 100755
--- a/main.py
+++ b/main.py
@@ -475,6 +475,11 @@ def NTC_masked_cross_entropy(output, targets, mask):
     return (loss_per_token * mask).mean()
 
 
+def masked_cross_entropy(output, targets, masks):
+    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+    return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
+
+
 ######################################################################
 
 
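The new masked_cross_entropy differs from NTC_masked_cross_entropy in its normalization: instead of averaging the masked per-token losses over every position, it divides by the number of masked positions, so sequences with few tokens to predict are not down-weighted. A minimal, self-contained sketch of the expected shapes (the dimensions below are illustrative, not taken from the repository):

    import torch
    import torch.nn.functional as F

    def masked_cross_entropy(output, targets, masks):
        # output: (N, T, C) logits, targets: (N, T) token ids, masks: (N, T) in {0, 1}
        loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
        # average over the masked (to-be-predicted) positions only
        return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()

    logits = torch.randn(2, 5, 10)              # N=2 sequences, T=5 tokens, C=10 classes
    targets = torch.randint(10, (2, 5))
    masks = (torch.rand(2, 5) < 0.5).long()     # predict roughly half of the positions
    print(masked_cross_entropy(logits, targets, masks))
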
@@ -506,7 +511,7 @@ def run_test(
                 x_0=x_0,
                 mask_generate=mask_generate,
             )
-            loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+            loss = masked_cross_entropy(logits, x_0, mask_generate)
             acc_test_loss += loss.item() * x_0.size(0)
             nb_test_samples += x_0.size(0)
 
@@ -600,7 +605,7 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
                 nb_hints=nb_hints,
             )
 
-        loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+        loss = masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
@@ -638,47 +643,43 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
 ######################################################################
 
 
-def batch_prediction(input, proba_hints=0.0):
+def IMT_batch_prediction(input, proba_hints=0.0):
     nb = input.size(0)
-    mask = input.new_zeros(input.size())
-    u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4)
-    mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+    masks = input.new_zeros(input.size())
+    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+    masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
 
     if proba_hints > 0:
-        h = torch.rand(input.size(), device=input.device) * mask
+        h = torch.rand(input.size(), device=input.device) * masks
         mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
         v = torch.rand(nb, device=input.device)[:, None]
         mask_hints = mask_hints * (v < proba_hints).long()
-        mask = (1 - mask_hints) * mask
+        masks = (1 - mask_hints) * masks
 
     # noise = quiz_machine.problem.pure_noise(nb, input.device)
     targets = input
-    input = (1 - mask) * targets  # + mask * noise
+    input = (1 - masks) * targets  # + masks * noise
 
-    return input, targets, mask
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-def predict(model, input, targets, mask, local_device=main_device):
+def predict(model, imt_set, local_device=main_device):
     model.eval().to(local_device)
 
-    input_batches = input.reshape(-1, args.physical_batch_size, input.size(1))
-    targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1))
-    mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1))
-
     record = []
 
-    for input, targets, mask in tqdm.tqdm(
-        zip(input_batches, targets_batches, mask_batches),
+    for imt in tqdm.tqdm(
+        imt_set.split(args.physical_batch_size),
         dynamic_ncols=True,
         desc="predict",
-        total=input.size(0) // args.physical_batch_size,
+        total=imt_set.size(0) // args.physical_batch_size,
     ):
-        # noise = quiz_machine.problem.pure_noise(input.size(0), input.device)
-        input = (1 - mask) * input  # + mask * noise
+        masks = imt[:, 1]
+        imt = imt * (1 - masks[:, None])  # paranoia
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(NTC_channel_cat(input, mask))
+            logits = model(imt[:, 0] * 2 + imt[:, 1])
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        result = (1 - mask) * input + mask * dist.sample()
+        result = (1 - masks) * imt[:, 0] + masks * dist.sample()
         record.append(result)
 
     return torch.cat(record)
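
The refactor replaces the separate input / targets / mask tensors with a single "IMT" tensor of shape (N, 3, T), which is what IMT_batch_prediction returns and what predict() now splits into physical batches. A minimal sketch of the packing and unpacking, assuming input, masks and targets are integer tensors of shape (N, T) (IMT_pack is an illustrative name, not a function from the repository):

    def IMT_pack(input, masks, targets):
        # stack the three (N, T) tensors along a new channel dimension -> (N, 3, T)
        return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)

    imt = IMT_pack(input, masks, targets)
    input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
    tokens = input * 2 + masks   # what the model receives; ids range over [0, 2 * vocabulary_size)
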
@@ -686,8 +687,10 @@ def predict(model, input, targets, mask, local_device=main_device):
 
 ######################################################################
 
+# IMT for input / masks / target
+
 
-def batch_generation(input):
+def IMT_batch_generation(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -704,10 +707,10 @@ def batch_generation(input):
 
     targets = input
     input = (1 - mask_erased) * input + mask_erased * noise
-    mask = input.new_full(input.size(), 1)
-    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
+    masks = input.new_full(input.size(), 1)
+    masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
 
-    return input, targets, mask
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
 def prioritized_rand(low):
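
The probs_iterations schedule in IMT_batch_generation, presumably used to pick the erasure level behind mask_erased, decays geometrically from 1.0 down to 0.1 across the diffusion iterations. The values are easy to inspect (5 iterations here purely for illustration, in place of args.diffusion_nb_iterations):

    import torch
    probs_iterations = 0.1 ** torch.linspace(0, 1, 5)
    print(probs_iterations)   # tensor([1.0000, 0.5623, 0.3162, 0.1778, 0.1000])
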
@@ -721,18 +724,18 @@ def prioritized_rand(low):
 
 def generate(model, nb, local_device=main_device):
     input = quiz_machine.problem.pure_noise(nb, local_device)
-    mask = input.new_full(input.size(), 1)
-    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
+    masks = input.new_full(input.size(), 1)
+    masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
 
     changed = True
     for it in range(args.diffusion_nb_iterations):
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(NTC_channel_cat(input, mask))
+            logits = model(input)
         dist = torch.distributions.categorical.Categorical(logits=logits)
         output = dist.sample()
 
         r = prioritized_rand(input != output)
-        mask_changes = (r <= args.diffusion_proba_corruption).long() * mask
+        mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
         update = (1 - mask_changes) * input + mask_changes * output
         if update.equal(input):
             break
@@ -768,16 +771,14 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     )
 
     input_p, input_g = quizzes.to(local_device).chunk(2)
-    input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5)
-    input_g, targets_g, mask_g = batch_generation(input_g)
+    imt_set = torch.cat(
+        [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)]
+    )
 
-    perm = torch.randperm(quizzes.size(0), device=local_device)
-    input_batches = batch_interleave(input_p, input_g, perm)
-    targets_batches = batch_interleave(targets_p, targets_g, perm)
-    mask_batches = batch_interleave(mask_p, mask_g, perm)
+    imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
 
-    for input, targets, mask in tqdm.tqdm(
-        zip(input_batches, targets_batches, mask_batches),
+    for imt in tqdm.tqdm(
+        imt_set.split(args.physical_batch_size),
         dynamic_ncols=True,
         desc=label,
         total=quizzes.size(0) // args.physical_batch_size,
@@ -786,11 +787,11 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
             model.optimizer.zero_grad()
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(NTC_channel_cat(input, mask))
+            logits = model(imt[:, 0] * 2 + imt[:, 1])
 
-        loss = NTC_masked_cross_entropy(logits, targets, mask)
-        acc_loss += loss.item() * input.size(0)
-        nb_samples += input.size(0)
+        loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1])
+        acc_loss += loss.item() * imt.size(0)
+        nb_samples += imt.size(0)
 
         if train:
             loss.backward()
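
Since prediction and generation examples now share the (N, 3, T) IMT layout, the old batch_interleave machinery goes away: the two sets are simply concatenated, shuffled with a random permutation, and split into physical batches. A minimal sketch of that pattern, assuming imt_p and imt_g are the two IMT sets and B stands in for args.physical_batch_size:

    imt_set = torch.cat([imt_p, imt_g])
    imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
    for imt in imt_set.split(B):   # the last chunk may be smaller than B
        tokens = imt[:, 0] * 2 + imt[:, 1]
        logits = model(tokens)
        loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1])
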
@@ -810,11 +811,11 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     # predict
 
     quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
-    input, targets, mask = batch_prediction(quizzes.to(local_device))
-    result = predict(model, input, targets, mask, local_device=local_device).to("cpu")
-    mask = mask.to("cpu")
+    imt_set = IMT_batch_prediction(quizzes.to(local_device))
+    result = predict(model, imt_set, local_device=local_device).to("cpu")
+    masks = imt_set[:, 1].to("cpu")
     correct = (quizzes == result).min(dim=1).values.long()
-    correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
+    correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
         :, :, 1
     ]
     predicted_parts = correct_parts.abs()
@@ -845,9 +846,8 @@ import attae
 models = []
 
 for i in range(args.nb_models):
-    # model = MyAttentionAE(
-    model = attae.MaskedAttentionAE(
-        vocabulary_size=vocabulary_size,
+    model = attae.AttentionAE(
+        vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
         dim_hidden=args.dim_hidden,
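
Switching from attae.MaskedAttentionAE to the plain attae.AttentionAE works because the mask bit now lives in the token ids themselves: with ids of the form t * 2 + m, an embedding table of size vocabulary_size * 2 gives every token a distinct "masked" and "unmasked" entry, while the original token and its mask bit remain recoverable. A quick sanity check of that encoding (vocabulary_size set to an arbitrary illustrative value):

    import torch

    vocabulary_size = 64                        # illustrative, not the repository's value
    t = torch.randint(vocabulary_size, (3, 7))  # token ids
    m = torch.randint(2, (3, 7))                # mask bits
    token = t * 2 + m
    assert token.max().item() < 2 * vocabulary_size
    assert (token // 2).equal(t) and (token % 2).equal(m)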