From 8adf0586ee5aeb9fbdf81b78c7ff4b484a9b82ab Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 08:50:07 +0200 Subject: [PATCH] Update. --- main.py | 108 ++++++++++++++++++++++++------------------------ tasks.py | 123 +++++++++++++++++++++++++++---------------------------- world.py | 2 +- 3 files changed, 115 insertions(+), 118 deletions(-) diff --git a/main.py b/main.py index 61d77ed..11d712a 100755 --- a/main.py +++ b/main.py @@ -668,28 +668,6 @@ vocabulary_size = task.vocabulary_size() log_string(f"vocabulary_size {vocabulary_size}") -############################## - -models = [] - -for k in range(2): - models.append( - mygpt.MyGPT( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - causal=True, - dropout=args.dropout, - ).to(device) - ) - - -nb_parameters = sum(p.numel() for p in models[0].parameters()) -log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") - ###################################################################### # Compute the entropy of the training tokens @@ -763,20 +741,16 @@ else: log_string(f"learning_rate_schedule {learning_rate_schedule}") -time_pred_result = None - ###################################################################### -def one_epoch(model, task, learning_rate): - log_string(f"learning_rate {learning_rate}") - +def one_epoch(model, task): if args.optim == "sgd": - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) elif args.optim == "adam": - optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) elif args.optim == "adamw": - optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) else: raise ValueError(f"Unknown optimizer {args.optim}.") @@ -840,7 +814,7 @@ def run_tests(model, task, deterministic_synthesis): log_string(f"test_perplexity {n_epoch} {test_perplexity}") - return main_test_accuracy + model.main_test_accuracy = main_test_accuracy ###################################################################### @@ -852,9 +826,6 @@ def create_quizzes( task, nb_for_train=1000, nb_for_test=100, - nb_runs=10, - nb_min_correct=9, - nb_max_correct=9, ): kept = [] @@ -866,14 +837,9 @@ def create_quizzes( nb=4 * (nb_for_train + nb_for_test), model=model, other_models=other_models, - nb_runs=nb_runs, ) - to_keep = new_quizzes[ - torch.logical_and( - nb_correct >= nb_min_correct, nb_correct <= nb_max_correct - ) - ] + to_keep = new_quizzes[nb_correct == len(other_models) - 1] log_string(f"keep {to_keep.size(0)} quizzes") kept.append(to_keep) @@ -890,29 +856,63 @@ def create_quizzes( ) +###################################################################### + +models = [] + +for k in range(5): + model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, + ).to(device) + + model.main_test_accuracy = 0.0 + model.id = k + + models.append(model) + + +nb_parameters = sum(p.numel() for p in models[0].parameters()) +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") + ###################################################################### accuracy_to_make_quizzes = 0.975 for n_epoch in range(args.nb_epochs): - learning_rate = learning_rate_schedule[n_epoch] + models.sort(key=lambda model: model.main_test_accuracy) - for m in models: - one_epoch(m, task, learning_rate) - test_accuracy = run_tests(m, task, deterministic_synthesis=False) + model = models[0] - if test_accuracy >= accuracy_to_make_quizzes: - other_models = models.copy() - other_models.remove(m) - create_quizzes(m, other_models, task) + log_string( + f"training model {model.id} main_test_accuracy {model.main_test_accuracy}" + ) + + one_epoch(model, task) - # -------------------------------------------- + log_string( + f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}" + ) - time_current_result = datetime.datetime.now() - if time_pred_result is not None: - log_string( - f"next_result {time_current_result + (time_current_result - time_pred_result)}" + run_tests(model, task, deterministic_synthesis=False) + + if model.main_test_accuracy >= accuracy_to_make_quizzes: + other_models = models.copy() + other_models.remove(model) + + create_quizzes( + model, + other_models, + task, + nb_for_train=1000, + nb_for_test=100, ) - time_pred_result = time_current_result + ###################################################################### diff --git a/tasks.py b/tasks.py index 0345bd0..b4829d9 100755 --- a/tasks.py +++ b/tasks.py @@ -2105,6 +2105,10 @@ class World(Task): torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) logger(f"wrote {image_name}") + def make_ar_mask(self, input): + b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2 + return b.long()[None, :].expand_as(input) + def __init__( self, nb_train_samples, @@ -2123,41 +2127,48 @@ class World(Task): self.train_input = world.generate( nb_train_samples, height=self.height, width=self.width - ) - self.train_ar_mask = ( - (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2) - .long()[None, :] - .expand_as(self.train_input) - ) + ).to(device) self.test_input = world.generate( nb_test_samples, height=self.height, width=self.width - ) - self.test_ar_mask = ( - (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2) - .long()[None, :] - .expand_as(self.test_input) - ) - - self.train_input, self.train_ar_mask = self.train_input.to( - device - ), self.train_ar_mask.to(device) - self.test_input, self.test_ar_mask = self.test_input.to( - device - ), self.test_ar_mask.to(device) + ).to(device) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + self.train_quizzes = [] + self.test_quizzes = [] + if result_dir is not None: self.save_image( self.train_input[:96], result_dir, f"world_train.png", logger ) - def batches(self, split="train", nb_to_use=-1, desc=None): + def batches(self, split="train", desc=None): assert split in {"train", "test"} - input = self.train_input if split == "train" else self.test_input - if nb_to_use > 0: - input = input[:nb_to_use] + if split == "train": + input = self.train_input + quizzes = self.train_quizzes + else: + input = self.test_input + quizzes = self.test_quizzes + + if len(quizzes) > 0: + quizzes = torch.cat(quizzes, dim=0) + if quizzes.size(0) > input.size(0) // 2: + i = torch.randperm(input.size(0))[: input.size(0) // 2] + quizzes = quizzes[i] + + i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)] + input = input[i] + + self.nb_batch_samples_world = input.size(0) + self.nb_batch_samples_quizzes = quizzes.size(0) + + input = torch.cat([input, quizzes], dim=0) + else: + self.nb_batch_samples_world = input.size(0) + self.nb_batch_samples_quizzes = 0 + if desc is None: desc = f"epoch-{split}" for batch in tqdm.tqdm( @@ -2171,8 +2182,9 @@ class World(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 ): - def compute_accuracy(input, ar_mask, logger=None): - input, ar_mask = input[:nmax], ar_mask[:nmax] + def compute_accuracy(input, logger=None): + input = input[:nmax] + ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( @@ -2192,17 +2204,13 @@ class World(Task): return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy( - self.train_input, self.train_ar_mask - ) + train_nb_total, train_nb_correct = compute_accuracy(self.train_input) logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy( - self.test_input, self.test_ar_mask, logger - ) + test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" @@ -2213,7 +2221,8 @@ class World(Task): ############################## - input, ar_mask = self.test_input[:96], self.test_ar_mask[:96] + input = self.test_input[:96] + ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( @@ -2233,19 +2242,19 @@ class World(Task): return main_test_accuracy def store_new_quizzes(self, new_quizzes, for_train=True): - input = self.train_input if for_train else self.test_input - - nb_current = input.size(0) - nb_new = new_quizzes.size(0) - if nb_new >= nb_current: - input[...] = new_quizzes[:nb_current] + if for_train: + self.train_quizzes.append(new_quizzes) else: - nb_kept = nb_current - nb_new - input[:nb_kept] = input[-nb_kept:].clone() - input[nb_kept:] = new_quizzes + self.test_quizzes.append(new_quizzes) def create_new_quizzes( - self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs + self, + n_epoch, + result_dir, + logger, + nb, + model, + other_models, ): new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 @@ -2262,35 +2271,23 @@ class World(Task): device=self.device, ) - input = ( - new_quizzes[:, None, :] - .expand(-1, nb_runs, -1) - .clone() - .reshape(-1, new_quizzes.size(-1)) - ) - result = input.clone() + ar_mask = self.make_ar_mask(new_quizzes) - ar_mask = ( - (torch.arange(result.size(1), device=self.device) > result.size(1) // 2) - .long()[None, :] - .expand_as(result) - ) + nb_correct = 0 - dispatch = torch.randint(len(other_models), (result.size(0),)) + for m in other_models: + result = new_quizzes.clone() - for n, m in enumerate(other_models): masked_inplace_autoregression( m, self.batch_size, - result[dispatch == n], - ar_mask[dispatch == n], - deterministic_synthesis=False, - progress_bar_desc=None, + result, + ar_mask, + deterministic_synthesis=True, + progress_bar_desc="solving quizzes", device=self.device, ) - nb_correct = ( - (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1) - ) + nb_correct += (new_quizzes == result).long().min(dim=-1).values return new_quizzes, nb_correct diff --git a/world.py b/world.py index 89833e6..118a470 100755 --- a/world.py +++ b/world.py @@ -41,7 +41,7 @@ def generate( f_end = torch.zeros(nb, height, width, dtype=torch.int64) n = torch.arange(f_start.size(0)) - for n in range(nb): + for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"): nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1 for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values: i, j = ( -- 2.20.1