X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=qmlp.py;h=abebfc17e1befec6e1e612b1197f47909df083cc;hb=HEAD;hp=b58598a8a04cc0b2d9947a774fc3047967932fcd;hpb=3e3bf1003aa0ecbf7d38b7b0c289fbe1cfa3101b;p=picoclvr.git diff --git a/qmlp.py b/qmlp.py index b58598a..abebfc1 100755 --- a/qmlp.py +++ b/qmlp.py @@ -92,23 +92,37 @@ def generate_sets_and_params( test_input = dequantize(q_test_input, -1, 1) if save_as_examples: - a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1 - xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1), - a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2) - xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1) + a = ( + 2 + * torch.arange(nb_quantization_levels).float() + / (nb_quantization_levels - 1) + - 1 + ) + xf = torch.cat( + [ + a[:, None, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + a[None, :, None].expand( + nb_quantization_levels, nb_quantization_levels, 1 + ), + ], + 2, + ) + xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1) print(f"{xf.size()=} {x.size()=}") yf = ( ( - (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long() - * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long() - * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long() - * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long() + (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long() + * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long() + * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long() + * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long() ) .max(dim=1) .values ) - full_input, full_targets = xf,yf + full_input, full_targets = xf, yf q_full_input = quantize(full_input, -1, 1) full_input = dequantize(q_full_input, -1, 1) @@ -208,8 +222,12 @@ def generate_sets_and_params( def evaluate_q_params( - q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024, - save_as_examples=False, + q_params, + q_set, + batch_size=25, + device=torch.device("cpu"), + nb_mlps_per_batch=1024, + save_as_examples=False, ): errors = [] nb_mlps = q_params.size(0)