From: François Fleuret Date: Wed, 3 Jul 2024 16:54:46 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b2d0cd51d40fb16eeba8cc620be99cdd77d593d0;p=culture.git Update. --- diff --git a/lang.py b/lang.py index 43550d7..abb7ca2 100755 --- a/lang.py +++ b/lang.py @@ -34,13 +34,11 @@ class Lang(problem.Problem): def __init__( self, - nb_iterations=2, ): self.colors = torch.tensor([c for _, c in self.named_colors]) self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) self.height = 10 self.width = 10 - self.nb_iterations = nb_iterations ###################################################################### @@ -173,8 +171,9 @@ class Lang(problem.Problem): return len(self.colors) def rec_coo(self, x, n, min_height=3, min_width=3): + collision = x.new(x.size()) while True: - collision = x.new_zeros(x.size()) + collision[...] = 0 result = [] for _ in range(n): while True: @@ -263,13 +262,14 @@ class Lang(problem.Problem): r = self.rec_coo(X, N) for n in range(N): i1, j1, i2, j2 = r[n] - X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n] - f_X[i1 : (i1 + i2) // 2, j1:j2] = c[2 * n] - X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1] + i = (i1 + i2) // 2 + X[i1:i2, j1:j2] = c[2 * n] + X[i : i + 1, j1:j2] = c[2 * n + 1] + f_X[i1:i2, j1:j2] = c[2 * n] if n == N - 1: - f_X[(i1 + i2) // 2 : i2, j1:j2] = c[2 * n + 1] + f_X[i:i2, j1:j2] = c[2 * n + 1] else: - f_X[(i1 + i2) // 2 : (i1 + i2) // 2 + 1, j1:j2] = c[2 * n + 1] + f_X[i : i + 1, j1:j2] = c[2 * n + 1] def task_frame(self, A, f_A, B, f_B): N = 3 @@ -301,7 +301,8 @@ class Lang(problem.Problem): f_A = prompt[:, 1 * w : 2 * w] B = prompt[:, 2 * w : 3 * w] f_B = answer - tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B) + task = tasks[torch.randint(len(tasks), (1,))] + task(A, f_A, B, f_B) return prompts.flatten(1), answers.flatten(1) def save_quizzes( @@ -328,9 +329,12 @@ class Lang(problem.Problem): if __name__ == "__main__": import time - lang = Lang(nb_iterations=4) + lang = Lang() - prompts, answers = lang.generate_prompts_and_answers(36) + start_time = time.perf_counter() + prompts, answers = lang.generate_prompts_and_answers(100) + delay = time.perf_counter() - start_time + print(f"{prompts.size(0)/delay:02f} seq/s") # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 # predicted_answers = torch.logical_not(predicted_prompts) @@ -338,8 +342,8 @@ if __name__ == "__main__": lang.save_quizzes( "/tmp", "test", - prompts, - answers, + prompts[:36], + answers[:36], # You can add a bool to put a frame around the predicted parts # predicted_prompts, predicted_answers ) diff --git a/main.py b/main.py index b4e7318..fe010ce 100755 --- a/main.py +++ b/main.py @@ -250,7 +250,7 @@ if args.problem == "sky": speed=args.sky_speed, ) elif args.problem == "lang": - problem = lang.Lang(nb_iterations=2) + problem = lang.Lang() else: raise ValueError