Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 06:51:57 +0000 (09:51 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 6 Jul 2024 06:51:57 +0000 (09:51 +0300)
main.py
reasoning.py

diff --git a/main.py b/main.py
index 02e1a8d..ff573c4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, reasoning, quiz_machine
+import sky, grids, quiz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -251,8 +251,8 @@ if args.problem == "sky":
         speed=args.sky_speed,
     )
     back_accuracy = False
-elif args.problem == "reasoning":
-    problem = reasoning.Reasoning(device=device)
+elif args.problem == "grids":
+    problem = grids.Grids(device=device)
     back_accuracy = True
 else:
     raise ValueError
index 058c410..9462f87 100755 (executable)
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 import problem
 
 
-class Reasoning(problem.Problem):
+class Grids(problem.Problem):
     named_colors = [
         ("white", [255, 255, 255]),
         ("red", [255, 0, 0]),
@@ -421,7 +421,7 @@ class Reasoning(problem.Problem):
                 if n < nb_rec - 1:
                     f_X[i1, j1] = c[-1]
 
-    def contact(X, i, j, q):
+    def contact(self, X, i, j, q):
         nq, nq_diag = 0, 0
         no = 0
 
@@ -466,7 +466,7 @@ class Reasoning(problem.Problem):
             k = torch.randperm(self.height * self.width)
             for p in range(self.height * self.width):
                 i, j = k[p] % self.height, k[p] // self.height
-                no, nq, nq_diag = contact(X, i, j, c[q[p]])
+                no, nq, nq_diag = self.contact(X, i, j, c[q[p]])
                 if no == 0 and nq_diag == 0:
                     if nq == 0:
                         if nb[q[p]] < self.width:
@@ -693,19 +693,20 @@ if __name__ == "__main__":
 
     nb = 48
 
-    reasoning = Reasoning()
+    grids = Grids()
 
-    for t in [reasoning.task_islands]:  # reasoning.all_tasks():
+    for t in grids.all_tasks():
+        # for t in [grids.task_islands]:
         print(t.__name__)
-        prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t])
-        reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+        prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
+        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
 
     exit(0)
 
     nb = 72
 
     start_time = time.perf_counter()
-    prompts, answers = reasoning.generate_prompts_and_answers(nb)
+    prompts, answers = grids.generate_prompts_and_answers(nb)
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
@@ -713,7 +714,7 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    reasoning.save_quizzes(
+    grids.save_quizzes(
         "/tmp",
         "test",
         prompts[:nb],