# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm, os, warnings, cairo
+import math, sys, tqdm, os, warnings, cairo, re
import torch, torchvision
return ar_mask
def text2quiz(self, t):
    """Parse a plain-text description of quizzes into a tensor of tokens.

    Text format:
      * a quiz is written as `self.height` rows; each row holds the
        corresponding row of the four grids A, f_A, B, f_B side by side,
        one character per cell (whitespace is ignored);
      * quizzes are separated by an empty line or by a ";";
      * "#" starts a comment that runs to the end of the line.

    Args:
        t: the text to parse.

    Returns:
        A tensor with one row per quiz. Each row is the concatenation,
        for the four grids in the order A, f_A, B, f_B, of the grid's
        marker token followed by its height * width color tokens.

    Raises:
        KeyError: if the text contains a character that is not a known
            color code.
        ValueError: if the text contains no quiz at all.
    """
    # One character per color name of self.named_colors. "l" used to be
    # listed twice (lightgreen, then lightblue), the second entry silently
    # winning; "l" keeps meaning lightblue and "e" now encodes lightgreen.
    chr2col = {
        ".": "white",
        "r": "red",
        "g": "green",
        "b": "blue",
        "y": "yellow",
        "c": "cyan",
        "v": "violet",
        "e": "lightgreen",
        "o": "brown",
        "l": "lightblue",
        "a": "gray",
    }

    col2tok = {c[0]: n for n, c in enumerate(self.named_colors)}
    chr2tok = {c: col2tok[col] for c, col in chr2col.items()}

    # Comment-only lines are dropped together with their newline, so they
    # can sit anywhere, even between two rows of the same quiz.
    t = re.sub(r"(?m)^[ \t]*#[^\n]*\n?", "", t)
    # Trailing comments are dropped but their newline is kept, so that a
    # comment on the last row of a quiz does not glue it to the next one.
    t = re.sub(r"#[^\n]*", "", t)

    # Column of the four grid-marker tokens, prepended to the grid cells.
    markers = torch.tensor(
        [
            [self.token_A],
            [self.token_f_A],
            [self.token_B],
            [self.token_f_B],
        ]
    )

    result = []

    for chunk in re.split(r";|\n\n", t):
        symbols = "".join(chunk.split())

        # Runs of blank lines produce empty chunks, which are not quizzes.
        if not symbols:
            continue

        unknown = sorted(set(symbols) - chr2tok.keys())
        if unknown:
            raise KeyError(
                f"text2quiz: unknown color character(s) {unknown}, "
                f"expected one of {sorted(chr2tok)}"
            )

        cells = torch.tensor([chr2tok[c] for c in symbols])
        # The text is laid out as (row, grid, column); reorder it to
        # (grid, row, column) and flatten each grid.
        cells = (
            cells.reshape(self.height, 4, self.width).permute(1, 0, 2).flatten(1)
        )
        quiz = torch.cat([markers, cells], dim=1)
        result.append(quiz.flatten()[None, :])

    if not result:
        raise ValueError("text2quiz: no quiz found in the text")

    return torch.cat(result, dim=0)
+
def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
S = self.height * self.width
q = quizzes.reshape(quizzes.size(0), 4, S + 1)
grids = Grids()
- # nb = 5
- # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
- # print(quizzes)
- # print(grids.get_order(quizzes))
- # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
- # print("DEBUG2", quizzes)
- # print(grids.get_order(quizzes))
- # print(quizzes)
-
- # i = torch.rand(quizzes.size(0)) < 0.5
-
- # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
-
- # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
-
- # print(
- # i.equal(j),
- # grids.get_order(quizzes[j]),
- # grids.get_order(quizzes[j == False]),
- # )
-
- # exit(0)
-
- # nb = 1000
- # grids = problem.MultiThreadProblem(
- # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
- # )
- # time.sleep(10)
- # start_time = time.perf_counter()
- # prompts, answers = grids.generate_w_quizzes(nb)
- # delay = time.perf_counter() - start_time
- # print(f"{prompts.size(0)/delay:02f} seq/s")
- # exit(0)
-
- # if True:
+ q = grids.text2quiz(
+ """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+ )
+
+ grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1)
+
+ exit(0)
+
nb, nrow = 128, 4
# nb, nrow = 8, 2
parser.add_argument("--test", type=str, default=None)
+parser.add_argument("--quizzes", type=str, default=None)
+
######################################################################
grids_tasks = ", ".join(
######################################################################
-def quiz_validation_paris(models, c_quizzes, local_device):
+def quiz_validation(models, c_quizzes, local_device):
nb_have_to_be_correct = 3
- nb_have_to_be_wrong = 1
-
- nb_runs = 3
+ nb_have_to_be_wrong = 3
nb_mistakes_to_be_wrong = 5
record_wrong = []
for i, model in enumerate(models):
assert i == model.id # a bit of paranoia
model = copy.deepcopy(model).to(local_device).eval()
- correct, wrong = True, False
- for _ in range(nb_runs):
- n = model_ae_argmax_nb_mistakes(model, c_quizzes).long()
- correct = correct & (n == 0)
- wrong = wrong | (n >= nb_mistakes_to_be_wrong)
- record_wrong.append(wrong[:, None])
- nb_correct += correct.long()
- nb_wrong += wrong.long()
-
- # print("nb_correct", nb_correct)
-
- # print("nb_wrong", nb_wrong)
-
- to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
- wrong = torch.cat(record_wrong, dim=1)
- return to_keep, wrong
-
-
-def quiz_validation_berne(models, c_quizzes, local_device):
- nb_have_to_be_correct = 3
- nb_have_to_be_wrong = 1
- nb_runs = 3
-
- record_wrong = []
- nb_correct, nb_wrong = 0, 0
+ correct, wrong = True, False
- for i, model in enumerate(models):
- assert i == model.id # a bit of paranoia
- model = copy.deepcopy(model).to(local_device).eval()
- log_probas = 0
- for _ in range(nb_runs):
- log_probas += model_ae_proba_solutions(
- model, c_quizzes, log_probas=True, reduce=False
+ for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- probas = log_probas.exp()
- correct = (probas <= 0.75).long().sum(dim=1) == 0
- wrong = ((probas <= 0.125).long().sum(dim=1) >= 5) & (
- log_probas.sum(dim=1).div(nb_runs).exp() <= 0.5
- )
+ result = ae_generate(
+ model,
+ (1 - mask_generate) * c_quizzes,
+ mask_generate,
+ )
+ nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+ correct = correct & (nb_mistakes == 0)
+ wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+
record_wrong.append(wrong[:, None])
nb_correct += correct.long()
nb_wrong += wrong.long()
return to_keep, wrong
+######################################################################
+
+
def generate_ae_c_quizzes(models, nb, local_device=main_device):
# To be thread-safe we must make copies
nb=args.inference_batch_size, quad_order=quad_order
).to(local_device)
- mask_generate = quiz_machine.make_quiz_mask(
- quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
- )
-
wanted_nb = nb
nb_to_save = 256
nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
generator_id = model.id
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
+ )
+
c_quizzes = ae_generate(model, template, mask_generate)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=c_quizzes,
+ quad_order=("A", "f_A", "B", "f_B"),
+ quad_mask=quad,
+ )
+ c_quizzes = ae_generate(
+ model,
+ (1 - mask_generate) * c_quizzes,
+ mask_generate,
+ )
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
to_keep = quiz_machine.problem.trivial(c_quizzes) == False
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
- to_keep, record_wrong = quiz_validation_berne(
- models, c_quizzes, local_device
- )
+ to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
q = c_quizzes[to_keep]
if q.size(0) > 0:
c_quizzes = torch.cat(record_c_quizzes, dim=0)
agreements = torch.cat(record_agreements, dim=0)
- subset_c_quizzes = c_quizzes[:nb_to_save]
-
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- # for model in models:
- # for r in range(3):
- # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png"
- # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes)
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename,
- # quizzes=p,
- # delta=True,
- # nrow=8,
- # )
- # log_string(f"wrote {filename}")
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ return c_quizzes, agreements
- filename = f"culture_c_quiz_{n_epoch:04d}.png"
- l = [
- model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
- for model in models
- ]
- probas = torch.cat([x[:, None] for x in l], dim=1)
- comments = []
-
- for l in probas:
- comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=subset_c_quizzes,
- comments=comments,
- delta=True,
- nrow=8,
- )
def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
    """Run generate_ae_c_quizzes and append what it returns to `record`.

    The result is handed back through the `record` list instead of a
    return value; the function itself returns None.
    """
    generated = generate_ae_c_quizzes(models, nb, local_device)
    record.append(generated)
- log_string(f"wrote {filename}")
- return c_quizzes, agreements
+######################################################################
-def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
- record.append(generate_ae_c_quizzes(models, nb, local_device))
def save_c_quizzes_with_scores(models, c_quizzes, filename):
    """Save `c_quizzes` as an image annotated with one score per model.

    Args:
        models: the models used to score the quizzes.
        c_quizzes: the quizzes to score and draw.
        filename: name of the image file, written in args.result_dir.
    """
    # One tensor per model; presumably one value per quiz -- each is
    # turned into a column below so that row k holds the scores of quiz k.
    # NOTE(review): the code this was extracted from wrapped each model in
    # copy_for_inference() before scoring -- confirm dropping it is intended.
    scores = [model_ae_proba_solutions(model, c_quizzes) for model in models]

    probas = torch.cat([x[:, None] for x in scores], dim=1)

    comments = [
        "proba " + " ".join(f"{x.item():.02f}" for x in row) for row in probas
    ]

    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir,
        filename,
        # This used to reference `subset_c_quizzes`, a local of the
        # function this code came from that does not exist here.
        quizzes=c_quizzes,
        comments=comments,
        delta=True,
        nrow=8,
    )

    log_string(f"wrote {filename}")
######################################################################
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+######################################################################
+
if args.quizzes is not None:
    # Stand-alone mode: read hand-written quizzes from a text file, let
    # every model regenerate each of the four grids in turn, save all the
    # results in a single image, and stop.
    with open(args.quizzes, "r") as file:
        txt = file.read()

    quizzes = quiz_machine.problem.text2quiz(txt).to(main_device)

    # Each mask selects exactly one of the four grids for regeneration.
    single_quad_masks = ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))

    record = []

    for model in models:
        log_string(f"processing {model.id} {args.quizzes}")
        for quad in single_quad_masks:
            mask_generate = quiz_machine.make_quiz_mask(
                quizzes=quizzes,
                quad_order=("A", "f_A", "B", "f_B"),
                quad_mask=quad,
            )
            # The masked positions are zeroed in the input given to the model.
            masked_quizzes = (1 - mask_generate) * quizzes
            record.append(ae_generate(model, masked_quizzes, mask_generate))

    result = torch.cat(record, dim=0)

    filename = "result.png"

    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir, filename, quizzes=result, delta=True, nrow=8
    )

    log_string(f"wrote {filename}")

    exit(0)
+
+
######################################################################
last_n_epoch_c_quizzes = 0
# --------------------------------------------------------------------
+ filename = f"culture_c_quiz_{n_epoch:04d}.png"
+ save_c_quizzes_with_scores(models, c_quizzes[:128], filename)
+
log_string(f"generated_c_quizzes {c_quizzes.size()=}")
time_train = 0