- (
- train_frames,
- train_action_seq,
- test_frames,
- test_action_seq,
- self.frame2seq,
- self.seq2frame,
- ) = world.create_data_and_processors(
- nb_train_samples,
- nb_test_samples,
- mode="first_last",
- nb_steps=30,
- nb_epochs=vqae_nb_epochs,
- logger=logger,
- device=device,
- device_storage=device_storage,
- )
-
- train_frame_seq = self.frame2seq(train_frames).to(device_storage)
- test_frame_seq = self.frame2seq(test_frames).to(device_storage)
-
- nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
- nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
-
- self.len_frame_seq = train_frame_seq.size(1)
- self.len_action_seq = train_action_seq.size(1)
- self.nb_codes = nb_frame_codes + nb_action_codes
-
- train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)