Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:11:36 +0000 (09:11 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:11:36 +0000 (09:11 +0200)
main.py

diff --git a/main.py b/main.py
index 4488a70..5f80fb5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -415,17 +415,22 @@ def batch_prediction_imt(input, fraction_with_hints=0.0):
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-def predict(model, imt_set, local_device=main_device):
+def predict(model, imt_set, local_device=main_device, desc="predict"):
     model.eval().to(local_device)
 
     record = []
 
-    for imt in tqdm.tqdm(
-        imt_set.split(args.physical_batch_size),
-        dynamic_ncols=True,
-        desc="predict",
-        total=imt_set.size(0) // args.physical_batch_size,
-    ):
+    src = imt_set.split(args.physical_batch_size)
+
+    if desc is not None:
+        src = tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc=desc,
+            total=imt_set.size(0) // args.physical_batch_size,
+        )
+
+    for imt in src:
         # some paranoia
         imt = imt.clone()
         imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
@@ -452,7 +457,7 @@ def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device
         [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1
     )
 
-    result = predict(model, imt_set, local_device=local_device)
+    result = predict(model, imt_set, local_device=local_device, desc=None)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
 
     return result
@@ -491,21 +496,26 @@ def prioritized_rand(low):
     return y
 
 
-def generate(model, nb, local_device=main_device):
+def generate(model, nb, local_device=main_device, desc="generate"):
     model.eval().to(local_device)
 
     all_input = quiz_machine.pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
 
-    for input, masks in tqdm.tqdm(
-        zip(
-            all_input.split(args.physical_batch_size),
-            all_masks.split(args.physical_batch_size),
-        ),
-        dynamic_ncols=True,
-        desc="generate",
-        total=all_input.size(0) // args.physical_batch_size,
-    ):
+    src = zip(
+        all_input.split(args.physical_batch_size),
+        all_masks.split(args.physical_batch_size),
+    )
+
+    if desc is not None:
+        src = tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc=desc,
+            total=all_input.size(0) // args.physical_batch_size,
+        )
+
+    for input, masks in src:
         changed = True
         for it in range(args.diffusion_nb_iterations):
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
@@ -877,12 +887,14 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         generator_id = model.id
 
         c_quizzes = generate(
-            moel=copy_for_inference(model),
+            model=model,
             nb=args.physical_batch_size,
             local_device=local_device,
+            desc=None,
         )
 
         nb_correct, nb_wrong = 0, 0
+
         for i, model in enumerate(models):
             model = copy.deepcopy(model).to(local_device).eval()
             result = predict_full(model, c_quizzes, local_device=local_device)
@@ -897,6 +909,8 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         nb_validated += to_keep.long().sum()
         record.append(c_quizzes[to_keep])
 
+        log_string(f"generate_c_quizzes {nb_validated}")
+
     return torch.cat(record)