Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 12:05:17 +0000 (14:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 12:05:17 +0000 (14:05 +0200)
main.py

diff --git a/main.py b/main.py
index ce86a76..a353868 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -431,8 +431,10 @@ def predict(model, imt_set, local_device=main_device):
         desc="predict",
         total=imt_set.size(0) // args.physical_batch_size,
     ):
-        masks = imt[:, 1]
-        imt = imt * (1 - masks[:, None])  # paranoia
+        # some paranoia
+        imt = imt.clone()
+        imt[:, 0] = imt[:, 0] * (1 - imt[:1])
+
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
             logits = model(imt[:, 0] * 2 + imt[:, 1])
         dist = torch.distributions.categorical.Categorical(logits=logits)
@@ -494,7 +496,7 @@ def generate(model, nb, local_device=main_device):
         changed = True
         for it in range(args.diffusion_nb_iterations):
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = model(input)
+                logits = model(input * 2 + masks)
             dist = torch.distributions.categorical.Categorical(logits=logits)
             output = dist.sample()
 
@@ -507,7 +509,7 @@ def generate(model, nb, local_device=main_device):
                 changed = changed & (update != input).max(dim=1).values
                 input[changed] = update[changed]
 
-    return input
+    return all_input
 
 
 ######################################################################
@@ -563,7 +565,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
 
 
-def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+######################################################################
+
+
+def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     # train
 
     one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
@@ -575,11 +580,13 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     imt_set = IMT_batch_prediction(quizzes.to(local_device))
     result = predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
+
     correct = (quizzes == result).min(dim=1).values.long()
     correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
         :, :, 1
     ]
     predicted_parts = correct_parts.abs()
+
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_prediction_{n_epoch}_{model.id}.png",
@@ -1063,7 +1070,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
 
     multithread_execution(
-        one_train_test_epoch,
+        one_complete_epoch,
         [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
     )