From: François Fleuret Date: Mon, 15 Jul 2024 14:00:31 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=03257cc01488588246fe23eabf54acaa2ac32442;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 7752136..a115f93 100755 --- a/grids.py +++ b/grids.py @@ -1126,7 +1126,7 @@ class Grids(problem.Problem): ) def save_some_examples(self, result_dir): - nb, nrow = 72, 4 + nb, nrow = 128, 4 for t in self.all_tasks: print(t.__name__) prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) @@ -1155,7 +1155,7 @@ if __name__ == "__main__": # exit(0) # if True: - nb, nrow = 72, 4 + nb, nrow = 128, 4 # nb, nrow = 8, 2 # for t in grids.all_tasks: diff --git a/main.py b/main.py index 4673f42..b372f12 100755 --- a/main.py +++ b/main.py @@ -372,12 +372,12 @@ def one_epoch(model, quiz_machine, local_device=main_device): # token_logprobas are NxMxT where M is the number of models +# def compute_valid_quizzes_(token_logprobas): +# warnings.warn("validation with uniform constraints", RuntimeWarning) +# l = token_logprobas.min(dim=-1).values.sort(dim=-1).values +# return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) - -def compute_valid_quizzes_(token_logprobas): - warnings.warn("validation with uniform constraints", RuntimeWarning) - l = token_logprobas.min(dim=-1).values.sort(dim=-1).values - return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) +# token_logprobas are NxMxT where M is the number of models def compute_valid_quizzes(token_logprobas):