Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 28 Mar 2024 13:55:14 +0000 (14:55 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 28 Mar 2024 13:55:14 +0000 (14:55 +0100)
main.py
tasks.py

diff --git a/main.py b/main.py
index 0f2cb61..9437136 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -706,7 +706,7 @@ if args.task == "expr" and args.expr_input_file is not None:
 # Compute the entropy of the training tokens
 
 token_count = 0
-for input in task.batches(split="train"):
+for input in task.batches(split="train", desc="train-entropy"):
     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
 token_probas = token_count / token_count.sum()
 entropy = -torch.xlogy(token_probas, token_probas).sum()
@@ -728,9 +728,13 @@ if args.max_percents_of_test_in_train >= 0:
         yield s
 
     nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
+    for test_subset in subsets_as_tuples(
+        task.batches(split="test", desc="test-check"), 25000
+    ):
         in_train = set()
-        for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
+        for train_subset in subsets_as_tuples(
+            task.batches(split="train", desc="train-check"), 25000
+        ):
             in_train.update(test_subset.intersection(train_subset))
         nb_in_train += len(in_train)
         nb_test += len(test_subset)
index 324376d..3ef64d7 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -1944,7 +1944,7 @@ class Greed(Task):
                 progress_bar_desc=None,
             )
             warnings.warn("keeping thinking snapshots", RuntimeWarning)
-            snapshots.append(result[:10].detach().clone())
+            snapshots.append(result[:100].detach().clone())
 
         # Generate iteration after iteration
 
@@ -1986,11 +1986,11 @@ class Greed(Task):
             # Set the lookahead_reward to UNKNOWN for the next iterations
             result[
                 :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN)
+            ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
 
         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
-            for n in range(10):
+            for n in range(snapshots[0].size(0)):
                 for s in snapshots:
                     lr, s, a, r = self.world.seq2episodes(
                         s[n : n + 1],