batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
)
+ nb_rec = 8
+ nb_values = 2 # more increases the min-max gap
+
+ rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
+
while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
nb = i.sum()
-
- nb_rec = 8
- nb_values = 2 # more increases the min-max gap
support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
support = support.sort(-1).values
support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
.values
)
- data_input[i], data_targets[i] = x, y
+ data_input[i], data_targets[i], rec_support[i] = x, y, support
train_input, train_targets = (
data_input[:, :nb_samples],
q_train_input = quantize(train_input, -1, 1)
train_input = dequantize(q_train_input, -1, 1)
- train_targets = train_targets
q_test_input = quantize(test_input, -1, 1)
test_input = dequantize(q_test_input, -1, 1)
- test_targets = test_targets
if save_as_examples:
- for k in range(q_train_input.size(0)):
- with open(f"example_{k:04d}.dat", "w") as f:
+ a = 2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1) - 1
+ xf = torch.cat([a[:,None,None].expand(nb_quantization_levels, nb_quantization_levels,1),
+ a[None,:,None].expand(nb_quantization_levels, nb_quantization_levels,1)], 2)
+ xf = xf.reshape(1,-1,2).expand(min(q_train_input.size(0),10),-1,-1)
+ print(f"{xf.size()=} {x.size()=}")
+ yf = (
+ (
+ (xf[:, None, :, 0] >= rec_support[:xf.size(0), :, None, 0]).long()
+ * (xf[:, None, :, 0] <= rec_support[:xf.size(0), :, None, 1]).long()
+ * (xf[:, None, :, 1] >= rec_support[:xf.size(0), :, None, 2]).long()
+ * (xf[:, None, :, 1] <= rec_support[:xf.size(0), :, None, 3]).long()
+ )
+ .max(dim=1)
+ .values
+ )
+
+ full_input, full_targets = xf,yf
+
+ q_full_input = quantize(full_input, -1, 1)
+ full_input = dequantize(q_full_input, -1, 1)
+
+ for k in range(q_full_input[:10].size(0)):
+ with open(f"example_full_{k:04d}.dat", "w") as f:
+ for u, c in zip(full_input[k], full_targets[k]):
+ f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+ for k in range(q_train_input[:10].size(0)):
+ with open(f"example_train_{k:04d}.dat", "w") as f:
for u, c in zip(train_input[k], train_targets[k]):
f.write(f"{c} {u[0].item()} {u[1].item()}\n")
if __name__ == "__main__":
import time
- batch_nb_mlps, nb_samples = 128, 2500
+ batch_nb_mlps, nb_samples = 128, 250
generate_sets_and_params(
batch_nb_mlps=10,
self.train_input = seq[:nb_train_samples]
self.train_q_test_set = q_test_set[:nb_train_samples]
+ self.train_ref_test_errors = test_error[:nb_train_samples]
self.test_input = seq[nb_train_samples:]
self.test_q_test_set = q_test_set[nb_train_samples:]
- self.ref_test_errors = test_error
+ self.test_ref_test_errors = test_error[nb_train_samples:]
+
+ filename = os.path.join(result_dir, f"train_errors_ref.dat")
+ with open(filename, "w") as f:
+ for e in self.train_ref_test_errors:
+ f.write(f"{e}\n")
filename = os.path.join(result_dir, f"test_errors_ref.dat")
with open(filename, "w") as f:
- for e in self.ref_test_errors:
+ for e in self.test_ref_test_errors:
f.write(f"{e}\n")
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1