Remove obsolete helpers (model_proba_solutions, batches, NTC_channel_cat, NTC_masked_cross_entropy, run_test, one_epoch_, batch_interleave), process generate() over physical batches, and generate 150 samples instead of 25.
author	François Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 07:45:36 +0000 (09:45 +0200)
committer	François Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 07:45:36 +0000 (09:45 +0200)
main.py

diff --git a/main.py b/main.py
index 3339838..8054509 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -390,91 +390,6 @@ data_structures = [
 ######################################################################
 
 
-def model_proba_solutions(model, input, log_probas=False, reduce=True):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        loss = 0
-
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-                prompt_noise=args.prompt_noise,
-            )
-            loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), x_0, reduction="none"
-            )
-            if reduce:
-                loss += (loss_per_token * mask_generate).sum(dim=1)
-            else:
-                loss += loss_per_token * mask_generate
-
-        record.append(loss)
-
-    loss = torch.cat(record, dim=0)
-
-    if log_probas:
-        return -loss
-    else:
-        return (-loss).exp()
-
-
-######################################################################
-
-
-def batches(
-    quiz_machine,
-    nb,
-    data_structures,
-    local_device,
-    c_quizzes=None,
-    alien_quiz_machine=None,
-    desc=None,
-    batch_size=args.batch_size,
-):
-    c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
-
-    full_input, full_mask_generate, _ = quiz_machine.data_input(
-        nb,
-        c_quiz_bags,
-        data_structures=data_structures,
-        c_quiz_multiplier=args.c_quiz_multiplier,
-    )
-
-    src = zip(
-        full_input.split(batch_size),
-        full_mask_generate.split(batch_size),
-    )
-
-    if desc is not None:
-        src = tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc=desc,
-            total=full_input.size(0) // batch_size,
-        )
-
-    for input, mask_generate in src:
-        yield (
-            input.to(local_device),
-            mask_generate.to(local_device),
-        )
-
-
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def NTC_masked_cross_entropy(output, targets, mask):
-    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * mask).mean()
-
-
 def masked_cross_entropy(output, targets, masks):
     loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
     return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum()
@@ -482,165 +397,7 @@ def masked_cross_entropy(output, targets, masks):
 
 ######################################################################
 
-
-def run_test(
-    model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
-):
-    if prefix is None:
-        prefix = ""
-    else:
-        prefix = prefix + "_"
-
-    with torch.autograd.no_grad():
-        model.eval().to(local_device)
-
-        # Compute the loss
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-
-        for x_0, mask_generate in batches(
-            quiz_machine,
-            args.nb_test_samples,
-            data_structures,
-            local_device,
-            c_quizzes=c_quizzes,
-            desc="test",
-        ):
-            logits = diffuser.logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-            )
-            loss = masked_cross_entropy(logits, x_0, mask_generate)
-            acc_test_loss += loss.item() * x_0.size(0)
-            nb_test_samples += x_0.size(0)
-
-        log_string(
-            f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
-        )
-
-        # Compute the accuracy and save some images
-
-        nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
-
-        for x_0, mask_generate in batches(
-            quiz_machine,
-            args.nb_test_samples,
-            data_structures,
-            local_device,
-            c_quizzes=c_quizzes,
-            desc="test",
-        ):
-            result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate)
-            correct = (result == x_0).min(dim=1).values.long()
-            predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
-                :, :, 1
-            ]
-            d = predicted_parts.sum(dim=-1) == 1
-            correct = (2 * correct - 1) * d.long()
-            nb_correct += (correct == 1).long().sum()
-            nb_total += (correct != 0).long().sum()
-            correct_parts = predicted_parts * correct[:, None]
-            record_d.append((result[d], predicted_parts[d], correct_parts[d]))
-            nd = d == False
-            record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd]))
-
-        log_string(
-            f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
-        )
-
-        # Save some images
-
-        for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            result, predicted_parts, correct_parts = bag_to_tensors(record)
-
-            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result[:128],
-                predicted_parts=predicted_parts[:128],
-                correct_parts=correct_parts[:128],
-            )
-
-            log_string(f"wrote {filename}")
-
-        return nb_correct / nb_total
-
-
-######################################################################
-
-
-def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
-    model.train().to(local_device)
-    optimizer_to(model.optimizer, local_device)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    # scaler = torch.amp.GradScaler("cuda")
-
-    for x_0, mask_generate in batches(
-        quiz_machine,
-        args.nb_train_samples,
-        data_structures,
-        local_device,
-        c_quizzes=c_quizzes,
-        desc="training",
-    ):
-        x_0 = x_0.to(local_device)
-        mask_generate = mask_generate.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.zero_grad()
-
-        nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints
-
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = diffuser.logits_hat_x_0_from_random_iteration(
-                model=model,
-                x_0=x_0,
-                mask_generate=mask_generate,
-                prompt_noise=args.prompt_noise,
-                nb_hints=nb_hints,
-            )
-
-        loss = masked_cross_entropy(logits, x_0, mask_generate)
-        acc_train_loss += loss.item() * x_0.size(0)
-        nb_train_samples += x_0.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
-
-        # scaler.scale(loss).backward()
-
-        # if nb_train_samples % args.batch_size == 0:
-        # scaler.step(model.optimizer)
-
-        # scaler.update()
-
-    log_string(
-        f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
-    )
-
-    model.test_accuracy = run_test(
-        model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
-    )
-
-    if args.nb_test_alien_samples > 0:
-        run_test(
-            model,
-            alien_quiz_machine,
-            n_epoch,
-            c_quizzes=None,
-            local_device=local_device,
-            prefix="alien",
-        )
-
-
-######################################################################
+# IMT for input / masks / target
 
 
 def IMT_batch_prediction(input, proba_hints=0.0):
@@ -687,8 +444,6 @@ def predict(model, imt_set, local_device=main_device):
 
 ######################################################################
 
-# IMT for input / masks / target
-
 
 def IMT_batch_generation(input):
     nb = input.size(0)
@@ -723,25 +478,34 @@ def prioritized_rand(low):
 
 
 def generate(model, nb, local_device=main_device):
-    input = quiz_machine.problem.pure_noise(nb, local_device)
-    masks = input.new_full(input.size(), 1)
-    masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
-
-    changed = True
-    for it in range(args.diffusion_nb_iterations):
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(input)
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-        output = dist.sample()
-
-        r = prioritized_rand(input != output)
-        mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
-        update = (1 - mask_changes) * input + mask_changes * output
-        if update.equal(input):
-            break
-        else:
-            changed = changed & (update != input).max(dim=1).values
-            input[changed] = update[changed]
+    all_input = quiz_machine.problem.pure_noise(nb, local_device)
+    all_masks = all_input.new_full(all_input.size(), 1)
+    all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0
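+    # mask out the first token of each of the four quiz parts so that it is
+    # never resampled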
+
+    for input, masks in tqdm.tqdm(
+        zip(
+            all_input.split(args.physical_batch_size),
+            all_masks.split(args.physical_batch_size),
+        ),
+        dynamic_ncols=True,
+        desc="generation",
+        total=all_input.size(0) // args.physical_batch_size,
+    ):
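+        # iterative denoising of this physical batch, starting from pure noise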
+        changed = True
+        for it in range(args.diffusion_nb_iterations):
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(input)
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            output = dist.sample()
+
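+            # overwrite a random subset of the maskable positions with the
+            # sampled tokens; prioritized_rand makes the draw depend on where
+            # the sample differs from the current input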
+            r = prioritized_rand(input != output)
+            mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+            update = (1 - mask_changes) * input + mask_changes * output
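+            # stop as soon as an iteration changes nothing, and freeze the
+            # sequences that no longer change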
+            if update.equal(input):
+                break
+            else:
+                changed = changed & (update != input).max(dim=1).values
+                input[changed] = update[changed]
 
-    return input
+    return all_input
 
@@ -749,10 +513,6 @@ def generate(model, nb, local_device=main_device):
 ######################################################################
 
 
-def batch_interleave(a, b, perm):
-    return torch.cat([a, b])[perm].reshape(-1, args.physical_batch_size, a.size(1))
-
-
 def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     if train:
         label = "train"
@@ -770,9 +530,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         args.c_quiz_multiplier,
     )
 
-    input_p, input_g = quizzes.to(local_device).chunk(2)
+    q1, q2 = quizzes.to(local_device).chunk(2)
+
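+    # half of the quizzes become prediction IMT triples (proba_hints=0.5),
+    # the other half generation triples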
     imt_set = torch.cat(
-        [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)]
+        [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)]
     )
 
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
@@ -831,7 +592,7 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # generate
 
-    result = generate(model, 25, local_device=local_device).to("cpu")
+    result = generate(model, 150, local_device=local_device).to("cpu")
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",