X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=reasoning.py;h=2874adc9c1c3f213bafce0c8dac53063c6463ab9;hb=f3ea2cbe833ff672ca5c41b98d583883ca233023;hp=768c15c40fdb7812e6a0c1eb8fe475a5e9ff7cf1;hpb=f4d12501685fe9b46a75e3768115f86ea9b75fa6;p=culture.git diff --git a/reasoning.py b/reasoning.py index 768c15c..2874adc 100755 --- a/reasoning.py +++ b/reasoning.py @@ -27,9 +27,9 @@ class Reasoning(problem.Problem): ("cyan", [0, 255, 255]), ("violet", [255, 0, 255]), ("lightgreen", [192, 255, 192]), - ("pink", [255, 192, 192]), + ("brown", [165, 42, 42]), ("lightblue", [192, 192, 255]), - ("gray", [192, 192, 192]), + ("gray", [128, 128, 128]), ] def __init__(self, device=torch.device("cpu")): @@ -294,7 +294,7 @@ class Reasoning(problem.Problem): f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] def task_move(self, A, f_A, B, f_B): - di, dj = torch.randint(2, (2,)) * 2 - 1 + di, dj = torch.randint(3, (2,)) - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: @@ -467,8 +467,8 @@ if __name__ == "__main__": delay = time.perf_counter() - start_time print(f"{prompts.size(0)/delay:02f} seq/s") - # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - # predicted_answers = torch.logical_not(predicted_prompts) + predicted_prompts = torch.rand(prompts.size(0)) < 0.5 + predicted_answers = torch.logical_not(predicted_prompts) reasoning.save_quizzes( "/tmp", @@ -476,5 +476,6 @@ if __name__ == "__main__": prompts[:64], answers[:64], # You can add a bool to put a frame around the predicted parts - # predicted_prompts, predicted_answers + predicted_prompts[:64], + predicted_answers[:64], )