- train_action_seq += nb_frame_codes
- self.train_input = torch.cat(
- (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
- )
-
- test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
- test_action_seq += nb_frame_codes
- self.test_input = torch.cat(
- (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
- )