From 2ac9d1299a84f96228f49fbdac02d5a7017445e5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 16 Jul 2023 15:31:38 +0200 Subject: [PATCH] Update. --- tasks.py | 11 +++++++---- world.py | 21 +++++++++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tasks.py b/tasks.py index df3fd81..f8fb9b9 100755 --- a/tasks.py +++ b/tasks.py @@ -959,6 +959,7 @@ class World(Task): vqae_nb_epochs, logger=None, device=torch.device("cpu"), + device_storage=torch.device("cpu"), ): self.batch_size = batch_size self.device = device @@ -978,12 +979,13 @@ class World(Task): nb_epochs=vqae_nb_epochs, logger=logger, device=device, + device_storage=device_storage, ) print(f"{train_action_seq.size()=}") - train_frame_seq = self.frame2seq(train_frames) - test_frame_seq = self.frame2seq(test_frames) + train_frame_seq = self.frame2seq(train_frames).to(device_storage) + test_frame_seq = self.frame2seq(test_frames).to(device_storage) nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1 nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1 @@ -993,6 +995,7 @@ class World(Task): self.nb_codes = nb_frame_codes + nb_action_codes train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1) + print(f"{train_action_seq.device=} {nb_frame_codes.device=}") train_action_seq += nb_frame_codes self.train_input = torch.cat( (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1 @@ -1014,7 +1017,7 @@ class World(Task): for batch in tqdm.tqdm( input.split(self.batch_size), dynamic_ncols=True, desc=desc ): - yield batch + yield batch.to(self.device) def vocabulary_size(self): return self.nb_codes @@ -1026,7 +1029,7 @@ class World(Task): 2 * self.len_frame_seq + self.len_action_seq, device=self.device )[None, :] - input = self.test_input[:64] + input = self.test_input[:64].to(self.device) result = input.clone() ar_mask = ( diff --git a/world.py b/world.py index fb8609d..fa305cf 100755 --- a/world.py +++ b/world.py @@ -149,6 +149,7 @@ def train_encoder( acc_train_loss = 0.0 for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"): + input = input.to(device) z = encoder(input) zq = z if k < 2 else quantizer(z) output = decoder(zq) @@ -168,6 +169,7 @@ def train_encoder( acc_test_loss = 0.0 for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"): + input = input.to(device) z = encoder(input) zq = z if k < 1 else quantizer(z) output = decoder(zq) @@ -341,6 +343,7 @@ def create_data_and_processors( nb_steps, nb_epochs=10, device=torch.device("cpu"), + device_storage=torch.device("cpu"), logger=None, ): assert mode in ["first_last"] @@ -349,9 +352,9 @@ def create_data_and_processors( steps = [True] + [False] * (nb_steps + 1) + [True] train_input, train_actions = generate_episodes(nb_train_samples, steps) - train_input, train_actions = train_input.to(device), train_actions.to(device) + train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage) test_input, test_actions = generate_episodes(nb_test_samples, steps) - test_input, test_actions = test_input.to(device), test_actions.to(device) + test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage) encoder, quantizer, decoder = train_encoder( train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device @@ -360,21 +363,22 @@ def create_data_and_processors( quantizer.train(False) decoder.train(False) - z = encoder(train_input[:1]) - pow2 = (2 ** torch.arange(z.size(1), device=z.device))[None, None, :] + z = encoder(train_input[:1].to(device)) + pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :] z_h, z_w = z.size(2), z.size(3) def frame2seq(input, batch_size=25): seq = [] - + p = pow2.to(device) for x in input.split(batch_size): + x=x.to(device) z = encoder(x) ze_bool = (quantizer(z) >= 0).long() output = ( ze_bool.permute(0, 2, 3, 1).reshape( ze_bool.size(0), -1, ze_bool.size(1) ) - * pow2 + * p ).sum(-1) seq.append(output) @@ -383,9 +387,10 @@ def create_data_and_processors( def seq2frame(input, batch_size=25, T=1e-2): frames = [] - + p = pow2.to(device) for seq in input.split(batch_size): - zd_bool = (seq[:, :, None] // pow2) % 2 + seq = seq.to(device) + zd_bool = (seq[:, :, None] // p) % 2 zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2) logits = decoder(zd_bool * 2.0 - 1.0) logits = logits.reshape( -- 2.20.1