Update: add a device_storage option so the train/test sequences rest on a separate device and data is moved to the compute device per batch.
diff --git a/tasks.py b/tasks.py
index df3fd81..f8fb9b9 100755
--- a/tasks.py
+++ b/tasks.py
@@ -959,6 +959,7 @@ class World(Task):
         vqae_nb_epochs,
         logger=None,
         device=torch.device("cpu"),
+        device_storage=torch.device("cpu"),
     ):
         self.batch_size = batch_size
         self.device = device
@@ -978,12 +979,13 @@ class World(Task):
             nb_epochs=vqae_nb_epochs,
             logger=logger,
             device=device,
+            device_storage=device_storage,
         )
 
         print(f"{train_action_seq.size()=}")
 
-        train_frame_seq = self.frame2seq(train_frames)
-        test_frame_seq = self.frame2seq(test_frames)
+        train_frame_seq = self.frame2seq(train_frames).to(device_storage)
+        test_frame_seq = self.frame2seq(test_frames).to(device_storage)
 
         nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
         nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
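The two code counts above build one shared vocabulary: frame codes keep their values and action codes are shifted up by nb_frame_codes, which is why the next hunk adds the offset in place. A toy illustration (the real sequences come from the VQ-AE tokenizer):

    import torch

    frame_seq = torch.tensor([[0, 3, 1], [2, 2, 0]])  # frame codes in [0, 4)
    action_seq = torch.tensor([[1], [0]])             # action codes in [0, 2)

    nb_frame_codes = frame_seq.max() + 1              # 4, as a 0-dim tensor
    action_seq += nb_frame_codes                      # codes now in [4, 6)

    combined = torch.cat((frame_seq, action_seq), 1)
    print(combined)  # one token stream over 4 + 2 = 6 codes
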
@@ -993,6 +995,7 @@ class World(Task):
         self.nb_codes = nb_frame_codes + nb_action_codes
 
         train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
+        print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
         train_action_seq += nb_frame_codes
         self.train_input = torch.cat(
             (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
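The print added in this hunk presumably guards the in-place `train_action_seq += nb_frame_codes` that follows it: nb_frame_codes is itself a 0-dim tensor (it comes from .max() + 1), and an in-place add fails when the operands sit on different devices. A toy demonstration of the constraint:

    import torch

    seq = torch.zeros(3, dtype=torch.long)  # on the CPU, like device_storage
    offset = torch.tensor(4)                # 0-dim tensor, like nb_frame_codes
    # an in-place add requires both operands on the same device, so move
    # the offset onto seq's device first if the two might differ
    seq += offset.to(seq.device)
    print(seq)                              # tensor([4, 4, 4])
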
@@ -1014,7 +1017,7 @@ class World(Task):
         for batch in tqdm.tqdm(
             input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
-            yield batch
+            yield batch.to(self.device)
 
     def vocabulary_size(self):
         return self.nb_codes
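The change to batches() above is the heart of the commit: the dataset stays on device_storage and each minibatch is copied to the compute device only as it is yielded. A stripped-down sketch of the same generator pattern (names illustrative):

    import torch

    def batches(data, batch_size, device):
        for batch in data.split(batch_size):
            # only batch_size rows occupy `device` at any moment
            yield batch.to(device)

    data = torch.randn(1_000, 16)            # rests on the CPU
    for b in batches(data, 256, torch.device("cpu")):
        pass                                 # a training step would go here

With a CUDA compute device, pinning the storage tensor (data.pin_memory()) and passing non_blocking=True to .to() would let the copy overlap with computation; the evaluation hunk below applies the same idea by moving only self.test_input[:64] to self.device.
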
@@ -1026,7 +1029,7 @@ class World(Task):
             2 * self.len_frame_seq + self.len_action_seq, device=self.device
         )[None, :]
 
-        input = self.test_input[:64]
+        input = self.test_input[:64].to(self.device)
         result = input.clone()
 
         ar_mask = (