- seq_start = input[:, : self.len_frame_seq]
- seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
- seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
-
- result = torch.cat(
- (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
- )
- result = result.reshape(-1, result.size(-1))
+ q_train_set = result[:, : self.nb_samples_per_mlp * 3]
+ q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
+ error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)