X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=qmlp.py;h=abebfc17e1befec6e1e612b1197f47909df083cc;hb=fdc61b7e50e029aac58b10f377acdce549532f84;hp=572cde1d5643bc45700090e66715d724f8a3259d;hpb=4aa7e109b4c712643cdddc2480b66d8799f71d3f;p=picoclvr.git diff --git a/qmlp.py b/qmlp.py index 572cde1..abebfc1 100755 --- a/qmlp.py +++ b/qmlp.py @@ -53,12 +53,14 @@ def generate_sets_and_params( batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device ) + nb_rec = 8 + nb_values = 2 # more increases the min-max gap + + rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device) + while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1: i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1 nb = i.sum() - - nb_rec = 8 - nb_values = 2 # more increases the min-max gap support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1 support = support.sort(-1).values support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4) @@ -75,7 +77,7 @@ def generate_sets_and_params( .values ) - data_input[i], data_targets[i] = x, y + data_input[i], data_targets[i], rec_support[i] = x, y, support train_input, train_targets = ( data_input[:, :nb_samples], @@ -85,15 +87,53 @@ def generate_sets_and_params( q_train_input = quantize(train_input, -1, 1) train_input = dequantize(q_train_input, -1, 1) - train_targets = train_targets q_test_input = quantize(test_input, -1, 1) test_input = dequantize(q_test_input, -1, 1) - test_targets = test_targets if save_as_examples: - for k in range(q_train_input.size(0)): - with open(f"example_{k:04d}.dat", "w") as f: + a = ( + 2 + * torch.arange(nb_quantization_levels).float() + / (nb_quantization_levels - 1) + - 1 + ) + xf = torch.cat( + [ + a[:, None, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + a[None, :, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + ], + 2, + ) + xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1) + print(f"{xf.size()=} {x.size()=}") + yf = ( + ( + (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long() + ) + .max(dim=1) + .values + ) + + full_input, full_targets = xf, yf + + q_full_input = quantize(full_input, -1, 1) + full_input = dequantize(q_full_input, -1, 1) + + for k in range(q_full_input[:10].size(0)): + with open(f"example_full_{k:04d}.dat", "w") as f: + for u, c in zip(full_input[k], full_targets[k]): + f.write(f"{c} {u[0].item()} {u[1].item()}\n") + + for k in range(q_train_input[:10].size(0)): + with open(f"example_train_{k:04d}.dat", "w") as f: for u, c in zip(train_input[k], train_targets[k]): f.write(f"{c} {u[0].item()} {u[1].item()}\n") @@ -182,8 +222,12 @@ def generate_sets_and_params( def evaluate_q_params( - q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024, - save_as_examples=False, + q_params, + q_set, + batch_size=25, + device=torch.device("cpu"), + nb_mlps_per_batch=1024, + save_as_examples=False, ): errors = [] nb_mlps = q_params.size(0) @@ -293,7 +337,7 @@ def generate_sequence_and_test_set( if __name__ == "__main__": import time - batch_nb_mlps, nb_samples = 128, 2500 + batch_nb_mlps, nb_samples = 128, 250 generate_sets_and_params( batch_nb_mlps=10,