From: François Fleuret Date: Fri, 14 Jul 2023 17:32:44 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=bf48dc69f7f57ad391481c8917570e35f661cc4a;p=culture.git Update. --- diff --git a/main.py b/main.py index 80f2733..c763016 100755 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ parser.add_argument( "--task", type=str, default="picoclvr", - help="picoclvr, mnist, maze, snake, stack, expr", + help="picoclvr, mnist, maze, snake, stack, expr, world", ) parser.add_argument("--log_filename", type=str, default="train.log", help=" ") @@ -110,7 +110,7 @@ parser.add_argument("--snake_nb_colors", type=int, default=5) parser.add_argument("--snake_length", type=int, default=200) ############################## -# Snake options +# Stack options parser.add_argument("--stack_nb_steps", type=int, default=100) @@ -181,6 +181,12 @@ default_args = { "nb_train_samples": 1000000, "nb_test_samples": 10000, }, + "world": { + "nb_epochs": 5, + "batch_size": 25, + "nb_train_samples": 10000, + "nb_test_samples": 1000, + }, } if args.task in default_args: @@ -317,6 +323,14 @@ elif args.task == "expr": device=device, ) +elif args.task == "world": + task = tasks.World( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}") diff --git a/tasks.py b/tasks.py index 75781ab..15d97b8 100755 --- a/tasks.py +++ b/tasks.py @@ -590,8 +590,6 @@ class Snake(Task): ) result *= 1 - ar_mask - # snake.solver(result,ar_mask) - masked_inplace_autoregression( model, self.batch_size, @@ -605,19 +603,8 @@ class Snake(Task): nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum() - # nb_total = result.size(0) - # nb_correct = ((result - input).abs().sum(1) == 0).sum() - return nb_total, nb_correct - # train_nb_total, train_nb_correct = compute_nb_correct( - # self.train_input, self.train_prior_visits - # ) - - # logger( - # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" - # ) - test_nb_total, test_nb_correct = compute_nb_correct( self.test_input[:1000], self.test_prior_visits[:1000] ) @@ -956,4 +943,57 @@ class Expr(Task): ############################################################## +###################################################################### +import world + + +class World(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + device=torch.device("cpu"), + ): + self.batch_size = batch_size + self.device = device + + ( + self.train_input, + self.train_actions, + self.test_input, + self.test_actions, + self.frame2seq, + self.seq2frame, + ) = world.create_data_and_processors( + nb_train_samples, + nb_test_samples, + mode="first_last", + nb_steps=30, + nb_epochs=2, + ) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + pass + + ###################################################################### diff --git a/world.py b/world.py index c33a584..c3eb101 100755 --- a/world.py +++ b/world.py @@ -322,17 +322,24 @@ def generate_episode(steps, size=64): def generate_episodes(nb, steps): - all_frames = [] + all_frames, all_actions = [], [] for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"): frames, actions = generate_episode(steps) all_frames += frames - return torch.cat(all_frames, 0).contiguous() + all_actions += [actions] + return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0) -def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10): - steps = [True] + [False] * 30 + [True] - train_input = generate_episodes(nb_train_samples, steps) - test_input = generate_episodes(nb_test_samples, steps) +def create_data_and_processors( + nb_train_samples, nb_test_samples, mode, nb_steps, nb_epochs=10 +): + assert mode in ["first_last"] + + if mode == "first_last": + steps = [True] + [False] * (nb_steps + 1) + [True] + + train_input, train_actions = generate_episodes(nb_train_samples, steps) + test_input, test_actions = generate_episodes(nb_test_samples, steps) encoder, quantizer, decoder = train_encoder( train_input, test_input, nb_epochs=nb_epochs @@ -380,15 +387,26 @@ def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10): return torch.cat(frames, dim=0) - return train_input, test_input, frame2seq, seq2frame + return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame ###################################################################### if __name__ == "__main__": - train_input, test_input, frame2seq, seq2frame = create_data_and_processors( + ( + train_input, + train_actions, + test_input, + test_actions, + frame2seq, + seq2frame, + ) = create_data_and_processors( # 10000, 1000, - 100, 100, nb_epochs=2, + 100, + 100, + nb_epochs=2, + mode="first_last", + nb_steps=20, ) input = test_input[:64]