Oups
[picoclvr.git] / qmlp.py
diff --git a/qmlp.py b/qmlp.py
index 572cde1..abebfc1 100755 (executable)
--- a/qmlp.py
+++ b/qmlp.py
@@ -53,12 +53,14 @@ def generate_sets_and_params(
         batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
     )
 
+    nb_rec = 8
+    nb_values = 2  # more increases the min-max gap
+
+    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
+
     while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
         i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
         nb = i.sum()
-
-        nb_rec = 8
-        nb_values = 2  # more increases the min-max gap
         support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
         support = support.sort(-1).values
         support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
@@ -75,7 +77,7 @@ def generate_sets_and_params(
             .values
         )
 
-        data_input[i], data_targets[i] = x, y
+        data_input[i], data_targets[i], rec_support[i] = x, y, support
 
     train_input, train_targets = (
         data_input[:, :nb_samples],
@@ -85,15 +87,53 @@ def generate_sets_and_params(
 
     q_train_input = quantize(train_input, -1, 1)
     train_input = dequantize(q_train_input, -1, 1)
-    train_targets = train_targets
 
     q_test_input = quantize(test_input, -1, 1)
     test_input = dequantize(q_test_input, -1, 1)
-    test_targets = test_targets
 
     if save_as_examples:
-        for k in range(q_train_input.size(0)):
-            with open(f"example_{k:04d}.dat", "w") as f:
+        a = (
+            2
+            * torch.arange(nb_quantization_levels).float()
+            / (nb_quantization_levels - 1)
+            - 1
+        )
+        xf = torch.cat(
+            [
+                a[:, None, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+                a[None, :, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+            ],
+            2,
+        )
+        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
+        print(f"{xf.size()=} {x.size()=}")
+        yf = (
+            (
+                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
+            )
+            .max(dim=1)
+            .values
+        )
+
+        full_input, full_targets = xf, yf
+
+        q_full_input = quantize(full_input, -1, 1)
+        full_input = dequantize(q_full_input, -1, 1)
+
+        for k in range(q_full_input[:10].size(0)):
+            with open(f"example_full_{k:04d}.dat", "w") as f:
+                for u, c in zip(full_input[k], full_targets[k]):
+                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+        for k in range(q_train_input[:10].size(0)):
+            with open(f"example_train_{k:04d}.dat", "w") as f:
                 for u, c in zip(train_input[k], train_targets[k]):
                     f.write(f"{c} {u[0].item()} {u[1].item()}\n")
 
@@ -182,8 +222,12 @@ def generate_sets_and_params(
 
 
 def evaluate_q_params(
-        q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024,
-        save_as_examples=False,
+    q_params,
+    q_set,
+    batch_size=25,
+    device=torch.device("cpu"),
+    nb_mlps_per_batch=1024,
+    save_as_examples=False,
 ):
     errors = []
     nb_mlps = q_params.size(0)
@@ -293,7 +337,7 @@ def generate_sequence_and_test_set(
 if __name__ == "__main__":
     import time
 
-    batch_nb_mlps, nb_samples = 128, 2500
+    batch_nb_mlps, nb_samples = 128, 250
 
     generate_sets_and_params(
         batch_nb_mlps=10,