Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 08:47:47 +0000 (11:47 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 08:47:47 +0000 (11:47 +0300)
grids.py
problem.py

index a2e253e..47e5861 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -194,9 +194,9 @@ class Grids(problem.Problem):
     def nb_token_values(self):
         return len(self.colors)
 
-    @torch.compile
+    @torch.compile
     def rec_coo_(self, nb_rec, min_height=3, min_width=3):
-        @torch.compile
+        @torch.compile
         def overlap(ia, ja, ib, jb):
             return (
                 ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1]
@@ -226,7 +226,7 @@ class Grids(problem.Problem):
     # non-overlapping rectangles quickly, but made the generation of
     # 100k samples go from 1h50 with a lame pure python code to 3min30s
     # with this one.
-    @torch.compile
+    @torch.compile
     def rec_coo(self, nb_rec, min_height=3, min_width=3):
         nb_trials = 200
 
@@ -288,7 +288,7 @@ class Grids(problem.Problem):
             )
         ]
 
-    @torch.compile
+    @torch.compile
     def rec_coo_(self, x, n, min_height=3, min_width=3):
         collision = x.new(x.size())
         while True:
@@ -313,7 +313,7 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    @torch.compile
+    @torch.compile
     def task_replace_color(self, A, f_A, B, f_B):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
@@ -324,7 +324,7 @@ class Grids(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
-    @torch.compile
+    @torch.compile
     def task_translate(self, A, f_A, B, f_B):
         di, dj = torch.randint(3, (2,)) - 1
         nb_rec = 3
@@ -349,7 +349,7 @@ class Grids(problem.Problem):
                 else:
                     f_X[i1:i2, j1:j2] = c[n]
 
-    @torch.compile
+    @torch.compile
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
@@ -375,7 +375,7 @@ class Grids(problem.Problem):
                     X[i1:i2, j1:j2] = c[n]
                     f_X[i1:i2, j1:j2] = c[n]
 
-    @torch.compile
+    @torch.compile
     def task_color_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
@@ -417,7 +417,7 @@ class Grids(problem.Problem):
                     else:
                         f_X[i1:i2, j : j + 1] = c[2 * n + 1]
 
-    @torch.compile
+    @torch.compile
     def task_frame(self, A, f_A, B, f_B):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
@@ -430,7 +430,7 @@ class Grids(problem.Problem):
                 if n == nb_rec - 1:
                     f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
 
-    @torch.compile
+    @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
@@ -442,7 +442,7 @@ class Grids(problem.Problem):
                 if n < nb_rec - 1:
                     f_X[i1, j1] = c[-1]
 
-    @torch.compile
+    @torch.compile
     def contact(self, X, i, j, q):
         nq, nq_diag = 0, 0
         no = 0
@@ -478,7 +478,7 @@ class Grids(problem.Problem):
 
         return no, nq, nq_diag
 
-    @torch.compile
+    @torch.compile
     def task_count(self, A, f_A, B, f_B):
         N = (torch.randint(4, (1,)) + 2).item()
         c = torch.randperm(len(self.colors) - 1)[:N] + 1
@@ -502,7 +502,7 @@ class Grids(problem.Problem):
                 for j in range(nb[n]):
                     f_X[n, j] = c[n]
 
-    @torch.compile
+    @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -530,12 +530,11 @@ class Grids(problem.Problem):
                 f_X[i + k * di, j + k * dj] = c[min(k, 1)]
                 k += 1
 
-    @torch.compile
+    @torch.compile
     def task_bounce(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:3] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-
-            @torch.compile
+            # @torch.compile
             def free(i, j):
                 return (
                     i >= 0
@@ -595,7 +594,7 @@ class Grids(problem.Problem):
                 if l > 3:
                     break
 
-    @torch.compile
+    @torch.compile
     def task_scale(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:2] + 1
 
@@ -620,7 +619,7 @@ class Grids(problem.Problem):
             X[i, j] = c[1]
             f_X[0:2, 0:2] = c[1]
 
-    @torch.compile
+    @torch.compile
     def task_symbols(self, A, f_A, B, f_B):
         nb_rec = 4
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
@@ -656,7 +655,7 @@ class Grids(problem.Problem):
 
             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
 
-    @torch.compile
+    @torch.compile
     def task_ortho(self, A, f_A, B, f_B):
         nb_rec = 3
         di, dj = torch.randint(3, (2,)) - 1
@@ -711,7 +710,7 @@ class Grids(problem.Problem):
                 ):
                     break
 
-    @torch.compile
+    @torch.compile
     def task_islands(self, A, f_A, B, f_B):
         pass
 
@@ -806,14 +805,25 @@ if __name__ == "__main__":
 
     grids = Grids()
 
-    if False:
-        nb = 8
+    # nb = 1000
+    # grids = problem.MultiThreadProblem(
+    # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
+    # )
+    #    time.sleep(10)
+    # start_time = time.perf_counter()
+    # prompts, answers = grids.generate_prompts_and_answers(nb)
+    # delay = time.perf_counter() - start_time
+    # print(f"{prompts.size(0)/delay:02f} seq/s")
+    # exit(0)
+
+    if True:
+        nb = 72
 
         for t in grids.all_tasks():
             # for t in [grids.task_ortho]:
             print(t.__name__)
             prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
-            grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=2)
+            grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
 
         exit(0)
 
index 7dd60dc..a49634d 100755 (executable)
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import threading, queue, torch
+import threading, queue, torch, tqdm
 
 
 class Problem:
@@ -33,11 +33,12 @@ class Problem:
 
 
 class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size):
+    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
         self.problem = problem
         self.chunk_size = chunk_size
         self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        threading.Thread(target=self.fill_cache, daemon=True).start()
+        for _ in range(nb_threads):
+            threading.Thread(target=self.fill_cache, daemon=True).start()
         self.rest = None
 
     def nb_token_values(self):
@@ -67,7 +68,7 @@ class MultiThreadProblem:
                 self.chunk_size
             )
 
-            self.queue.put((prompts, answers), block=True)
+            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
     def trivial_prompts_and_answers(self, prompts, answers):
         return self.problem.trivial_prompts_and_answers(prompts, answers)
@@ -82,13 +83,20 @@ class MultiThreadProblem:
 
         n = sum([p.size(0) for p in prompts])
 
-        while n < nb:
-            p, s = self.queue.get(block=True)
-            prompts.append(p)
-            answers.append(s)
-            n += p.size(0)
+        with tqdm.tqdm(
+            total=nb,
+            dynamic_ncols=True,
+            desc="world generation",
+        ) as pbar:
+            while n < nb:
+                p, s = self.queue.get(block=True)
+                prompts.append(p)
+                answers.append(s)
+                n += p.size(0)
+                pbar.update(p.size(0))
 
         prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+        assert n == prompts.size(0)
 
         k = n - nb