Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 17:04:17 +0000 (20:04 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 3 Jul 2024 17:04:17 +0000 (20:04 +0300)
main.py
reasoning.py [moved from lang.py with 98% similarity]

diff --git a/main.py b/main.py
index fe010ce..8b5d9a4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ from torch.nn import functional as F
 
 import ffutils
 import mygpt
-import sky, lang, quizz_machine
+import sky, reasoning, quizz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -249,8 +249,8 @@ if args.problem == "sky":
         nb_iterations=args.sky_nb_iterations,
         speed=args.sky_speed,
     )
-elif args.problem == "lang":
-    problem = lang.Lang()
+elif args.problem == "reasoning":
+    problem = reasoning.Reasoning()
 else:
     raise ValueError
 
similarity index 98%
rename from lang.py
rename to reasoning.py
index 1472e04..92699e8 100755 (executable)
--- a/lang.py
@@ -17,7 +17,7 @@ from torch.nn import functional as F
 import problem
 
 
-class Lang(problem.Problem):
+class Reasoning(problem.Problem):
     named_colors = [
         ("white", [255, 255, 255]),
         ("red", [255, 0, 0]),
@@ -340,17 +340,17 @@ class Lang(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    lang = Lang()
+    reasoning = Reasoning()
 
     start_time = time.perf_counter()
-    prompts, answers = lang.generate_prompts_and_answers(100)
+    prompts, answers = reasoning.generate_prompts_and_answers(100)
     delay = time.perf_counter() - start_time
     print(f"{prompts.size(0)/delay:02f} seq/s")
 
     # predicted_prompts = torch.rand(prompts.size(0)) < 0.5
     # predicted_answers = torch.logical_not(predicted_prompts)
 
-    lang.save_quizzes(
+    reasoning.save_quizzes(
         "/tmp",
         "test",
         prompts[:36],