X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=f8fb9b93ace534d6a225558f82b7d2d61211031a;hb=2ac9d1299a84f96228f49fbdac02d5a7017445e5;hp=df3fd81e516cc7c8080e55fddd01ef0401f1a55a;hpb=2192d72289bbf2cd069f67d3e93daf7934f886af;p=culture.git diff --git a/tasks.py b/tasks.py index df3fd81..f8fb9b9 100755 --- a/tasks.py +++ b/tasks.py @@ -959,6 +959,7 @@ class World(Task): vqae_nb_epochs, logger=None, device=torch.device("cpu"), + device_storage=torch.device("cpu"), ): self.batch_size = batch_size self.device = device @@ -978,12 +979,13 @@ class World(Task): nb_epochs=vqae_nb_epochs, logger=logger, device=device, + device_storage=device_storage, ) print(f"{train_action_seq.size()=}") - train_frame_seq = self.frame2seq(train_frames) - test_frame_seq = self.frame2seq(test_frames) + train_frame_seq = self.frame2seq(train_frames).to(device_storage) + test_frame_seq = self.frame2seq(test_frames).to(device_storage) nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1 nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1 @@ -993,6 +995,7 @@ class World(Task): self.nb_codes = nb_frame_codes + nb_action_codes train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1) + print(f"{train_action_seq.device=} {nb_frame_codes.device=}") train_action_seq += nb_frame_codes self.train_input = torch.cat( (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1 @@ -1014,7 +1017,7 @@ class World(Task): for batch in tqdm.tqdm( input.split(self.batch_size), dynamic_ncols=True, desc=desc ): - yield batch + yield batch.to(self.device) def vocabulary_size(self): return self.nb_codes @@ -1026,7 +1029,7 @@ class World(Task): 2 * self.len_frame_seq + self.len_action_seq, device=self.device )[None, :] - input = self.test_input[:64] + input = self.test_input[:64].to(self.device) result = input.clone() ar_mask = (