From 8e23dd068df00df61c690ffa89ecc8cb9db4b32d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 15 Jul 2023 10:21:09 +0200 Subject: [PATCH] Update. --- main.py | 6 ++++++ tasks.py | 47 ++++++++++++++++++++++++++++++++++++++++++----- world.py | 13 ++++++++++--- 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index c763016..58e8046 100755 --- a/main.py +++ b/main.py @@ -133,6 +133,11 @@ parser.add_argument("--expr_result_max", type=int, default=99) parser.add_argument("--expr_input_file", type=str, default=None) +############################## +# World options + +parser.add_argument("--world_vqae_nb_epochs", type=int, default=10) + ###################################################################### args = parser.parse_args() @@ -328,6 +333,7 @@ elif args.task == "world": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, + vqae_nb_epochs=args.world_vqae_nb_epochs, device=device, ) diff --git a/tasks.py b/tasks.py index 15d97b8..96d0621 100755 --- a/tasks.py +++ b/tasks.py @@ -20,6 +20,8 @@ def masked_inplace_autoregression( progress_bar_desc="autoregression", device=torch.device("cpu"), ): + assert input.size() == ar_mask.size() + batches = zip(input.split(batch_size), ar_mask.split(batch_size)) if progress_bar_desc is not None: @@ -27,7 +29,7 @@ def masked_inplace_autoregression( batches, dynamic_ncols=True, desc=progress_bar_desc, - total=input.size(0) // batch_size, + #total=input.size(0) // batch_size, ) with torch.autograd.no_grad(): @@ -944,6 +946,7 @@ class Expr(Task): ###################################################################### + import world @@ -953,15 +956,16 @@ class World(Task): nb_train_samples, nb_test_samples, batch_size, + vqae_nb_epochs, device=torch.device("cpu"), ): self.batch_size = batch_size self.device = device ( - self.train_input, + train_frames, self.train_actions, - self.test_input, + test_frames, self.test_actions, self.frame2seq, self.seq2frame, @@ -970,9 +974,15 @@ class World(Task): nb_test_samples, mode="first_last", nb_steps=30, - nb_epochs=2, + nb_epochs=vqae_nb_epochs, + device=device, ) + self.train_input = self.frame2seq(train_frames) + self.train_input = self.train_input.reshape(self.train_input.size(0) // 2, -1) + self.test_input = self.frame2seq(test_frames) + self.test_input = self.test_input.reshape(self.test_input.size(0) // 2, -1) + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): @@ -993,7 +1003,34 @@ class World(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - pass + l = self.train_input.size(1) + k = torch.arange(l, device=self.device)[None, :] + result = self.test_input[:64].clone() + + ar_mask = (k >= l // 2).long().expand_as(result) + result *= 1 - ar_mask + + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + device=self.device, + ) + + result = result.reshape(result.size(0) * 2, -1) + + frames = self.seq2frame(result) + image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png") + torchvision.utils.save_image( + frames.float() / (world.Box.nb_rgb_levels - 1), + image_name, + nrow=8, + padding=1, + pad_value=0.0, + ) + logger(f"wrote {image_name}") ###################################################################### diff --git a/world.py b/world.py index c3eb101..5c21fad 100755 --- a/world.py +++ b/world.py @@ -65,7 +65,7 @@ class SignSTE(nn.Module): def train_encoder( train_input, test_input, - depth=3, + depth=2, dim_hidden=48, nb_bits_per_token=8, lr_start=1e-3, @@ -331,7 +331,12 @@ def generate_episodes(nb, steps): def create_data_and_processors( - nb_train_samples, nb_test_samples, mode, nb_steps, nb_epochs=10 + nb_train_samples, + nb_test_samples, + mode, + nb_steps, + nb_epochs=10, + device=torch.device("cpu"), ): assert mode in ["first_last"] @@ -339,10 +344,12 @@ def create_data_and_processors( steps = [True] + [False] * (nb_steps + 1) + [True] train_input, train_actions = generate_episodes(nb_train_samples, steps) + train_input, train_actions = train_input.to(device), train_actions.to(device) test_input, test_actions = generate_episodes(nb_test_samples, steps) + test_input, test_actions = test_input.to(device), test_actions.to(device) encoder, quantizer, decoder = train_encoder( - train_input, test_input, nb_epochs=nb_epochs + train_input, test_input, nb_epochs=nb_epochs, device=device ) encoder.train(False) quantizer.train(False) -- 2.39.5