From 3060e09fe9c6d71f44482308c5876078c527bd70 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 20:04:17 +0300 Subject: [PATCH] Update. --- main.py | 6 +++--- lang.py => reasoning.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) rename lang.py => reasoning.py (98%) diff --git a/main.py b/main.py index fe010ce..8b5d9a4 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, lang, quizz_machine +import sky, reasoning, quizz_machine # world quizzes vs. culture quizzes @@ -249,8 +249,8 @@ if args.problem == "sky": nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, ) -elif args.problem == "lang": - problem = lang.Lang() +elif args.problem == "reasoning": + problem = reasoning.Reasoning() else: raise ValueError diff --git a/lang.py b/reasoning.py similarity index 98% rename from lang.py rename to reasoning.py index 1472e04..92699e8 100755 --- a/lang.py +++ b/reasoning.py @@ -17,7 +17,7 @@ from torch.nn import functional as F import problem -class Lang(problem.Problem): +class Reasoning(problem.Problem): named_colors = [ ("white", [255, 255, 255]), ("red", [255, 0, 0]), @@ -340,17 +340,17 @@ class Lang(problem.Problem): if __name__ == "__main__": import time - lang = Lang() + reasoning = Reasoning() start_time = time.perf_counter() - prompts, answers = lang.generate_prompts_and_answers(100) + prompts, answers = reasoning.generate_prompts_and_answers(100) delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 # predicted_answers = torch.logical_not(predicted_prompts) - lang.save_quizzes( + reasoning.save_quizzes( "/tmp", "test", prompts[:36], -- 2.39.5