Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 06:50:07 +0000 (08:50 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 06:50:07 +0000 (08:50 +0200)
main.py
tasks.py
world.py

diff --git a/main.py b/main.py
index 61d77ed..11d712a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -668,28 +668,6 @@ vocabulary_size = task.vocabulary_size()
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
-##############################
-
-models = []
-
-for k in range(2):
-    models.append(
-        mygpt.MyGPT(
-            vocabulary_size=vocabulary_size,
-            dim_model=args.dim_model,
-            dim_keys=args.dim_keys,
-            dim_hidden=args.dim_hidden,
-            nb_heads=args.nb_heads,
-            nb_blocks=args.nb_blocks,
-            causal=True,
-            dropout=args.dropout,
-        ).to(device)
-    )
-
-
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
-
 ######################################################################
 
 # Compute the entropy of the training tokens
@@ -763,20 +741,16 @@ else:
 
 log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
-time_pred_result = None
-
 ######################################################################
 
 
-def one_epoch(model, task, learning_rate):
-    log_string(f"learning_rate {learning_rate}")
-
+def one_epoch(model, task):
     if args.optim == "sgd":
-        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
     elif args.optim == "adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     elif args.optim == "adamw":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
     else:
         raise ValueError(f"Unknown optimizer {args.optim}.")
 
@@ -840,7 +814,7 @@ def run_tests(model, task, deterministic_synthesis):
 
         log_string(f"test_perplexity {n_epoch} {test_perplexity}")
 
-    return main_test_accuracy
+    model.main_test_accuracy = main_test_accuracy
 
 
 ######################################################################
@@ -852,9 +826,6 @@ def create_quizzes(
     task,
     nb_for_train=1000,
     nb_for_test=100,
-    nb_runs=10,
-    nb_min_correct=9,
-    nb_max_correct=9,
 ):
     kept = []
 
@@ -866,14 +837,9 @@ def create_quizzes(
             nb=4 * (nb_for_train + nb_for_test),
             model=model,
             other_models=other_models,
-            nb_runs=nb_runs,
         )
 
-        to_keep = new_quizzes[
-            torch.logical_and(
-                nb_correct >= nb_min_correct, nb_correct <= nb_max_correct
-            )
-        ]
+        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
         log_string(f"keep {to_keep.size(0)} quizzes")
         kept.append(to_keep)
 
@@ -890,29 +856,63 @@ def create_quizzes(
     )
 
 
+######################################################################
+
+models = []
+
+for k in range(5):
+    model = mygpt.MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        causal=True,
+        dropout=args.dropout,
+    ).to(device)
+
+    model.main_test_accuracy = 0.0
+    model.id = k
+
+    models.append(model)
+
+
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
 ######################################################################
 
 accuracy_to_make_quizzes = 0.975
 
 for n_epoch in range(args.nb_epochs):
-    learning_rate = learning_rate_schedule[n_epoch]
+    models.sort(key=lambda model: model.main_test_accuracy)
 
-    for m in models:
-        one_epoch(m, task, learning_rate)
-        test_accuracy = run_tests(m, task, deterministic_synthesis=False)
+    model = models[0]
 
-        if test_accuracy >= accuracy_to_make_quizzes:
-            other_models = models.copy()
-            other_models.remove(m)
-            create_quizzes(m, other_models, task)
+    log_string(
+        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+    )
+
+    one_epoch(model, task)
 
-    # --------------------------------------------
+    log_string(
+        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+    )
 
-    time_current_result = datetime.datetime.now()
-    if time_pred_result is not None:
-        log_string(
-            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+    run_tests(model, task, deterministic_synthesis=False)
+
+    if model.main_test_accuracy >= accuracy_to_make_quizzes:
+        other_models = models.copy()
+        other_models.remove(model)
+
+        create_quizzes(
+            model,
+            other_models,
+            task,
+            nb_for_train=1000,
+            nb_for_test=100,
         )
-    time_pred_result = time_current_result
+
 
 ######################################################################
index 0345bd0..b4829d9 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -2105,6 +2105,10 @@ class World(Task):
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
         logger(f"wrote {image_name}")
 
+    def make_ar_mask(self, input):
+        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
+        return b.long()[None, :].expand_as(input)
+
     def __init__(
         self,
         nb_train_samples,
@@ -2123,41 +2127,48 @@ class World(Task):
 
         self.train_input = world.generate(
             nb_train_samples, height=self.height, width=self.width
-        )
-        self.train_ar_mask = (
-            (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
-            .long()[None, :]
-            .expand_as(self.train_input)
-        )
+        ).to(device)
 
         self.test_input = world.generate(
             nb_test_samples, height=self.height, width=self.width
-        )
-        self.test_ar_mask = (
-            (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
-            .long()[None, :]
-            .expand_as(self.test_input)
-        )
-
-        self.train_input, self.train_ar_mask = self.train_input.to(
-            device
-        ), self.train_ar_mask.to(device)
-        self.test_input, self.test_ar_mask = self.test_input.to(
-            device
-        ), self.test_ar_mask.to(device)
+        ).to(device)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
+        self.train_quizzes = []
+        self.test_quizzes = []
+
         if result_dir is not None:
             self.save_image(
                 self.train_input[:96], result_dir, f"world_train.png", logger
             )
 
-    def batches(self, split="train", nb_to_use=-1, desc=None):
+    def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
+        if split == "train":
+            input = self.train_input
+            quizzes = self.train_quizzes
+        else:
+            input = self.test_input
+            quizzes = self.test_quizzes
+
+        if len(quizzes) > 0:
+            quizzes = torch.cat(quizzes, dim=0)
+            if quizzes.size(0) > input.size(0) // 2:
+                i = torch.randperm(input.size(0))[: input.size(0) // 2]
+                quizzes = quizzes[i]
+
+            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
+            input = input[i]
+
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = quizzes.size(0)
+
+            input = torch.cat([input, quizzes], dim=0)
+        else:
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = 0
+
         if desc is None:
             desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
@@ -2171,8 +2182,9 @@ class World(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
     ):
-        def compute_accuracy(input, ar_mask, logger=None):
-            input, ar_mask = input[:nmax], ar_mask[:nmax]
+        def compute_accuracy(input, logger=None):
+            input = input[:nmax]
+            ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
 
             masked_inplace_autoregression(
@@ -2192,17 +2204,13 @@ class World(Task):
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(
-            self.train_input, self.train_ar_mask
-        )
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
 
         logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(
-            self.test_input, self.test_ar_mask, logger
-        )
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
 
         logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@ -2213,7 +2221,8 @@ class World(Task):
 
         ##############################
 
-        input, ar_mask = self.test_input[:96], self.test_ar_mask[:96]
+        input = self.test_input[:96]
+        ar_mask = self.make_ar_mask(input)
         result = input.clone() * (1 - ar_mask)
 
         masked_inplace_autoregression(
@@ -2233,19 +2242,19 @@ class World(Task):
         return main_test_accuracy
 
     def store_new_quizzes(self, new_quizzes, for_train=True):
-        input = self.train_input if for_train else self.test_input
-
-        nb_current = input.size(0)
-        nb_new = new_quizzes.size(0)
-        if nb_new >= nb_current:
-            input[...] = new_quizzes[:nb_current]
+        if for_train:
+            self.train_quizzes.append(new_quizzes)
         else:
-            nb_kept = nb_current - nb_new
-            input[:nb_kept] = input[-nb_kept:].clone()
-            input[nb_kept:] = new_quizzes
+            self.test_quizzes.append(new_quizzes)
 
     def create_new_quizzes(
-        self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs
+        self,
+        n_epoch,
+        result_dir,
+        logger,
+        nb,
+        model,
+        other_models,
     ):
         new_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
@@ -2262,35 +2271,23 @@ class World(Task):
             device=self.device,
         )
 
-        input = (
-            new_quizzes[:, None, :]
-            .expand(-1, nb_runs, -1)
-            .clone()
-            .reshape(-1, new_quizzes.size(-1))
-        )
-        result = input.clone()
+        ar_mask = self.make_ar_mask(new_quizzes)
 
-        ar_mask = (
-            (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
-            .long()[None, :]
-            .expand_as(result)
-        )
+        nb_correct = 0
 
-        dispatch = torch.randint(len(other_models), (result.size(0),))
+        for m in other_models:
+            result = new_quizzes.clone()
 
-        for n, m in enumerate(other_models):
             masked_inplace_autoregression(
                 m,
                 self.batch_size,
-                result[dispatch == n],
-                ar_mask[dispatch == n],
-                deterministic_synthesis=False,
-                progress_bar_desc=None,
+                result,
+                ar_mask,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving quizzes",
                 device=self.device,
             )
 
-        nb_correct = (
-            (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
-        )
+            nb_correct += (new_quizzes == result).long().min(dim=-1).values
 
         return new_quizzes, nb_correct
index 89833e6..118a470 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -41,7 +41,7 @@ def generate(
     f_end = torch.zeros(nb, height, width, dtype=torch.int64)
     n = torch.arange(f_start.size(0))
 
-    for n in range(nb):
+    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
         nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
         for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values:
             i, j = (