- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"test_after {e}")
-
- logger(f"----------------------------------------------------------")
-
- q_train_set = result[:, : nb_samples * 3]
- q_params = result[:, nb_samples * 3 + 1 :]
- error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+ q_train_set = result[:, : self.nb_samples_per_mlp * 3]
+ q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
+ error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)