From 29c046c06ba14bf79738aba900dddd24cf8133c0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 9 Sep 2024 20:58:04 +0200 Subject: [PATCH] Update. --- grids.py | 141 ++++++++++++++++++++++++++++++---------- main.py | 195 ++++++++++++++++++++++++++++++------------------------- 2 files changed, 212 insertions(+), 124 deletions(-) diff --git a/grids.py b/grids.py index 73e722e..054ba35 100755 --- a/grids.py +++ b/grids.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm, os, warnings, cairo +import math, sys, tqdm, os, warnings, cairo, re import torch, torchvision @@ -234,6 +234,51 @@ class Grids(problem.Problem): return ar_mask + def text2quiz(self, t): + chr2col = [ + (".", "white"), + ("r", "red"), + ("g", "green"), + ("b", "blue"), + ("y", "yellow"), + ("c", "cyan"), + ("v", "violet"), + ("l", "lightgreen"), + ("o", "brown"), + ("l", "lightblue"), + ("a", "gray"), + ] + + col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)]) + chr2tok = dict([(c, col2tok[col]) for c, col in chr2col]) + + t = re.sub(r"#.*\n", "", t).strip() + l = t.replace("\n\n", ";").split(";") + + result = [] + + for t in l: + t = "".join(t.replace("\n", " ").strip().split(" ")) + t = torch.tensor([chr2tok[c] for c in t]) + t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1) + t = torch.cat( + [ + torch.tensor( + [ + [self.token_A], + [self.token_f_A], + [self.token_B], + [self.token_f_B], + ] + ), + t, + ], + dim=1, + ) + result.append(t.flatten()[None, :]) + + return torch.cat(result, dim=0) + def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")): S = self.height * self.width q = quizzes.reshape(quizzes.size(0), 4, S + 1) @@ -1798,41 +1843,65 @@ if __name__ == "__main__": grids = Grids() - # nb = 5 - # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) - # print(quizzes) - # print(grids.get_order(quizzes)) - # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) - # print("DEBUG2", quizzes) - # print(grids.get_order(quizzes)) - # print(quizzes) - - # i = torch.rand(quizzes.size(0)) < 0.5 - - # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A")) - - # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A")) - - # print( - # i.equal(j), - # grids.get_order(quizzes[j]), - # grids.get_order(quizzes[j == False]), - # ) - - # exit(0) - - # nb = 1000 - # grids = problem.MultiThreadProblem( - # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1 - # ) - # time.sleep(10) - # start_time = time.perf_counter() - # prompts, answers = grids.generate_w_quizzes(nb) - # delay = time.perf_counter() - start_time - # print(f"{prompts.size(0)/delay:02f} seq/s") - # exit(0) - - # if True: + q = grids.text2quiz( + """ + +# the original + +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa.... +....aaaaa. ....aaaaa. .vvvvv.... .rrrrr.... +.......... .......... .vvvvvvvvv .rrrrroooo +.......... .......... .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. .......... .......... +vvvvaaaaa. rrrraaaaa. .......aaa .......aaa +vvvvaaaaa. rrrraaaaa. .......aaa .......aaa +....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa +.......... .......... .vvvvvvvvv .rrrrroooo +.......... .......... .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +# +# so what +# + +vvvv...... rrrr...... .......... .......... +vvvv...... rrrr...... .......... .......... +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo + +vvvv...... rrrr...... .......... .......... +vvvv...... rrrr...... .......... .......... +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa +.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo +""" + ) + + grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1) + + exit(0) + nb, nrow = 128, 4 # nb, nrow = 8, 2 diff --git a/main.py b/main.py index d914113..97d37ce 100755 --- a/main.py +++ b/main.py @@ -129,6 +129,8 @@ parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test", type=str, default=None) +parser.add_argument("--quizzes", type=str, default=None) + ###################################################################### grids_tasks = ", ".join( @@ -1060,11 +1062,9 @@ def save_badness_statistics( ###################################################################### -def quiz_validation_paris(models, c_quizzes, local_device): +def quiz_validation(models, c_quizzes, local_device): nb_have_to_be_correct = 3 - nb_have_to_be_wrong = 1 - - nb_runs = 3 + nb_have_to_be_wrong = 3 nb_mistakes_to_be_wrong = 5 record_wrong = [] @@ -1073,47 +1073,22 @@ def quiz_validation_paris(models, c_quizzes, local_device): for i, model in enumerate(models): assert i == model.id # a bit of paranoia model = copy.deepcopy(model).to(local_device).eval() - correct, wrong = True, False - for _ in range(nb_runs): - n = model_ae_argmax_nb_mistakes(model, c_quizzes).long() - correct = correct & (n == 0) - wrong = wrong | (n >= nb_mistakes_to_be_wrong) - record_wrong.append(wrong[:, None]) - nb_correct += correct.long() - nb_wrong += wrong.long() - - # print("nb_correct", nb_correct) - - # print("nb_wrong", nb_wrong) - - to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong) - - wrong = torch.cat(record_wrong, dim=1) - return to_keep, wrong - - -def quiz_validation_berne(models, c_quizzes, local_device): - nb_have_to_be_correct = 3 - nb_have_to_be_wrong = 1 - nb_runs = 3 - - record_wrong = [] - nb_correct, nb_wrong = 0, 0 + correct, wrong = True, False - for i, model in enumerate(models): - assert i == model.id # a bit of paranoia - model = copy.deepcopy(model).to(local_device).eval() - log_probas = 0 - for _ in range(nb_runs): - log_probas += model_ae_proba_solutions( - model, c_quizzes, log_probas=True, reduce=False + for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad ) - probas = log_probas.exp() - correct = (probas <= 0.75).long().sum(dim=1) == 0 - wrong = ((probas <= 0.125).long().sum(dim=1) >= 5) & ( - log_probas.sum(dim=1).div(nb_runs).exp() <= 0.5 - ) + result = ae_generate( + model, + (1 - mask_generate) * c_quizzes, + mask_generate, + ) + nb_mistakes = (result != c_quizzes).long().sum(dim=1) + correct = correct & (nb_mistakes == 0) + wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong) + record_wrong.append(wrong[:, None]) nb_correct += correct.long() nb_wrong += wrong.long() @@ -1125,6 +1100,9 @@ def quiz_validation_berne(models, c_quizzes, local_device): return to_keep, wrong +###################################################################### + + def generate_ae_c_quizzes(models, nb, local_device=main_device): # To be thread-safe we must make copies @@ -1137,10 +1115,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): nb=args.inference_batch_size, quad_order=quad_order ).to(local_device) - mask_generate = quiz_machine.make_quiz_mask( - quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) - ) - wanted_nb = nb nb_to_save = 256 nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device) @@ -1155,15 +1129,31 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): model = copy_for_inference(models[torch.randint(len(models), (1,)).item()]) generator_id = model.id + mask_generate = quiz_machine.make_quiz_mask( + quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1) + ) + c_quizzes = ae_generate(model, template, mask_generate) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=c_quizzes, + quad_order=("A", "f_A", "B", "f_B"), + quad_mask=quad, + ) + c_quizzes = ae_generate( + model, + (1 - mask_generate) * c_quizzes, + mask_generate, + ) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + to_keep = quiz_machine.problem.trivial(c_quizzes) == False c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: - to_keep, record_wrong = quiz_validation_berne( - models, c_quizzes, local_device - ) + to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device) q = c_quizzes[to_keep] if q.size(0) > 0: @@ -1199,51 +1189,36 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): c_quizzes = torch.cat(record_c_quizzes, dim=0) agreements = torch.cat(record_agreements, dim=0) - subset_c_quizzes = c_quizzes[:nb_to_save] - - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - # for model in models: - # for r in range(3): - # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png" - # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes) - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename, - # quizzes=p, - # delta=True, - # nrow=8, - # ) - # log_string(f"wrote {filename}") - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + return c_quizzes, agreements - filename = f"culture_c_quiz_{n_epoch:04d}.png" - l = [ - model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes) - for model in models - ] - probas = torch.cat([x[:, None] for x in l], dim=1) - comments = [] - - for l in probas: - comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=subset_c_quizzes, - comments=comments, - delta=True, - nrow=8, - ) +def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): + record.append(generate_ae_c_quizzes(models, nb, local_device)) - log_string(f"wrote {filename}") - return c_quizzes, agreements +###################################################################### -def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): - record.append(generate_ae_c_quizzes(models, nb, local_device)) +def save_c_quizzes_with_scores(models, c_quizzes, filename): + l = [model_ae_proba_solutions(model, c_quizzes) for model in models] + + probas = torch.cat([x[:, None] for x in l], dim=1) + + comments = [] + + for l in probas: + comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=subset_c_quizzes, + comments=comments, + delta=True, + nrow=8, + ) + + log_string(f"wrote {filename}") ###################################################################### @@ -1284,6 +1259,47 @@ nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") +###################################################################### + +if args.quizzes is not None: + with open(args.quizzes, "r") as file: + txt = file.read() + + quizzes = quiz_machine.problem.text2quiz(txt) + + record = [] + + quizzes = quizzes.to(main_device) + for model in models: + log_string(f"processing {model.id} {args.quizzes}") + for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: + mask_generate = quiz_machine.make_quiz_mask( + quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad + ) + result = ae_generate( + model, + (1 - mask_generate) * quizzes, + mask_generate, + ) + record.append(result) + + result = torch.cat(record, dim=0) + + filename = "result.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + delta=True, + nrow=8, + ) + + log_string(f"wrote {filename}") + + exit(0) + + ###################################################################### last_n_epoch_c_quizzes = 0 @@ -1374,6 +1390,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- + filename = f"culture_c_quiz_{n_epoch:04d}.png" + save_c_quizzes_with_scores(models, c_quizzes[:128], filename) + log_string(f"generated_c_quizzes {c_quizzes.size()=}") time_train = 0 -- 2.39.5