vqae_nb_epochs,
logger=None,
device=torch.device("cpu"),
+ device_storage=torch.device("cpu"),
):
self.batch_size = batch_size
self.device = device
nb_epochs=vqae_nb_epochs,
logger=logger,
device=device,
+ device_storage=device_storage,
)
print(f"{train_action_seq.size()=}")
- train_frame_seq = self.frame2seq(train_frames)
- test_frame_seq = self.frame2seq(test_frames)
+ train_frame_seq = self.frame2seq(train_frames).to(device_storage)
+ test_frame_seq = self.frame2seq(test_frames).to(device_storage)
nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
self.nb_codes = nb_frame_codes + nb_action_codes
train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
+ print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
train_action_seq += nb_frame_codes
self.train_input = torch.cat(
(train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
for batch in tqdm.tqdm(
input.split(self.batch_size), dynamic_ncols=True, desc=desc
):
- yield batch
+ yield batch.to(self.device)
def vocabulary_size(self):
    """Return the vocabulary size, i.e. the value stored in ``self.nb_codes``."""
    # NOTE(review): elsewhere in this file nb_codes is assigned as
    # nb_frame_codes + nb_action_codes; it may be a 0-dim tensor rather
    # than a plain int — confirm before relying on its type.
    code_count = self.nb_codes
    return code_count
2 * self.len_frame_seq + self.len_action_seq, device=self.device
)[None, :]
- input = self.test_input[:64]
+ input = self.test_input[:64].to(self.device)
result = input.clone()
ar_mask = (