test_input = dequantize(q_test_input, -1, 1)
if save_as_examples:
- a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1
- xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1),
- a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2)
- xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1)
+ a = (
+ 2
+ * torch.arange(nb_quantization_levels).float()
+ / (nb_quantization_levels - 1)
+ - 1
+ )
+ xf = torch.cat(
+ [
+ a[:, None, None].expand(
+ nb_quantization_levels, nb_quantization_levels, 1
+ ),
+ a[None, :, None].expand(
+ nb_quantization_levels, nb_quantization_levels, 1
+ ),
+ ],
+ 2,
+ )
+ xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
print(f"{xf.size()=} {x.size()=}")
yf = (
(
- (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long()
- * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long()
- * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long()
- * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long()
+ (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+ * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+ * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+ * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
)
.max(dim=1)
.values
)
- full_input, full_targets = xf,yf
+ full_input, full_targets = xf, yf
q_full_input = quantize(full_input, -1, 1)
full_input = dequantize(q_full_input, -1, 1)
def evaluate_q_params(
- q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024,
- save_as_examples=False,
+ q_params,
+ q_set,
+ batch_size=25,
+ device=torch.device("cpu"),
+ nb_mlps_per_batch=1024,
+ save_as_examples=False,
):
errors = []
nb_mlps = q_params.size(0)