# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo
import torch, torchvision
######################################################################
+
def text_img(height, width, text):
    """Render *text* (possibly multi-line, '\n'-separated) as a black-on-white
    image and return it as a uint8 tensor of shape (1, 3, height, width).

    The pixel buffer is shared with cairo via numpy, so drawing writes
    directly into the returned tensor's storage.
    """
    # 255 everywhere = opaque white background in ARGB32.
    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)

    surface = cairo.ImageSurface.create_for_data(
        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
    )

    ctx = cairo.Context(surface)
    ctx.set_source_rgb(0, 0, 0)
    ctx.set_font_size(16)
    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)

    y = None
    for line in text.split("\n"):
        # NOTE: use dedicated names for the extents — the original code
        # unpacked into `width`/`height`, silently clobbering the function
        # parameters after the first iteration.
        xbearing, ybearing, t_width, t_height, dx, dy = ctx.text_extents(line)
        if y is None:
            y = t_height * 1.5
        # Left margin and inter-line spacing are both derived from the
        # current line's extent height, as in the original layout.
        x = t_height * 0.5

        ctx.move_to(x, y)
        ctx.show_text(line)
        y += t_height * 1.5

    ctx.stroke()

    # (H, W, 4) -> (1, 3, H, W): drop the alpha channel, add a batch dim.
    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
import problem
max_nb_cached_chunks=None,
chunk_size=None,
nb_threads=-1,
- tasks=None,
+ world_tasks=None,
+ science_tasks=None,
):
self.colors = torch.tensor([c for _, c in self.named_colors])
self.cache_rec_coo = {}
- all_tasks = [
+ self.all_tasks = [
self.task_replace_color,
self.task_translate,
self.task_grow,
# self.task_islands, # TOO MESSY
]
- if tasks is None:
- self.all_tasks = all_tasks
+ if world_tasks is None:
+ self.world_tasks = self.all_tasks
else:
- self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+ self.world_tasks = [
+ getattr(self, "task_" + t) for t in world_tasks.split(",")
+ ]
+
+ if science_tasks is not None:
+ self.science_tasks = [
+ getattr(self, "task_" + t) for t in science_tasks.split(",")
+ ]
super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
quizzes,
predicted_parts=None,
correct_parts=None,
+ comments=None,
+ comment_height=64,
nrow=4,
margin=8,
):
img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+ if comments is not None:
+ comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+ comment_img = torch.cat(comment_img, dim=0)
+ img = torch.cat([img, comment_img], dim=2)
+
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(
return quizzes
- def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+ def generate_w_quizzes_(self, nb, tasks=None, science=False, progress_bar=False):
S = self.height * self.width
if tasks is None:
- tasks = self.all_tasks
+ if science:
+ tasks = self.science_tasks
+ else:
+ tasks = self.world_tasks
quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
return quizzes
- def save_some_examples(self, result_dir):
+ def save_some_examples(self, result_dir, science=False):
nb, nrow = 128, 4
- for t in self.all_tasks:
+ tasks = self.science_tasks if science else self.world_tasks
+ for t in tasks:
print(t.__name__)
quizzes = self.generate_w_quizzes_(nb, tasks=[t])
self.save_quizzes_as_image(
nb, nrow = 128, 4
# nb, nrow = 8, 2
- # for t in grids.all_tasks:
+ # for t in grids.world_tasks:
+
for t in [grids.task_path]:
print(t.__name__)
quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
"/tmp",
t.__name__ + ".png",
quizzes,
+ comments=[f"{t.__name__} #{k}" for k in range(quizzes.size(0))],
)
# exit(0)
)
parser.add_argument(
- "--grids_tasks",
+ "--grids_world_tasks",
type=str,
default=None,
help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
)
+parser.add_argument(
+ "--grids_science_tasks",
+ type=str,
+ default=None,
+ help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
# World and science task sets must be disjoint. Both options default to
# None (meaning "all" / "unset"), so only run the check when both were
# actually provided — the original unconditionally called .split(",") and
# crashed with AttributeError in the default case.
if args.grids_world_tasks is not None and args.grids_science_tasks is not None:
    assert (
        len(
            set(args.grids_world_tasks.split(","))
            & set(args.grids_science_tasks.split(","))
        )
        == 0
    ), "World and science tasks have to be disjoint"
+
######################################################################
parser.add_argument("--sky_height", type=int, default=6)
nb_threads=args.nb_threads,
)
back_accuracy = False
+
elif args.problem == "grids":
problem = grids.Grids(
max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
chunk_size=100,
nb_threads=args.nb_threads,
- tasks=args.grids_tasks,
+ world_tasks=args.grids_world_tasks,
+ science_tasks=args.grids_science_tasks,
)
back_accuracy = True
+
else:
raise ValueError
# fail(s)
# This is nb_quizzes x nb_models
- number_correct_responses = 0
- remains = [c_quizzes.size(0)]
+ number_correct_responses = 0
+ nb_remaining = [c_quizzes.size(0)]
for r in range(args.nb_rounds):
if c_quizzes.size(0) == 0:
c_quizzes = c_quizzes[to_keep]
number_correct_responses = number_correct_responses[to_keep]
- remains.append(c_quizzes.size(0))
+ nb_remaining.append(c_quizzes.size(0))
if c_quizzes.size(0) > 0:
nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
else:
e = "???"
- v = " ".join([str(n) for n in remains])
+ v = " ".join([str(n) for n in nb_remaining])
log_string(f"filter c_quizzes {v}")
log_string(
v_train = validated_quizzes[:nb_for_train]
quiz_machine.store_c_quizzes(v_train, for_train=True)
- quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_train), for_train=True)
v_test = validated_quizzes[nb_for_train:nb_to_validate]
quiz_machine.store_c_quizzes(v_test, for_train=False)
- quiz_machine.store_c_quizzes(quiz_machine.problem.p_a_flip(v_test), for_train=False)
######################################################################
# save images
vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
if vq.size(0) > 0:
- prefix = f"culture_c_quiz_{n_epoch:04d}"
-
number_correct_responses = 0
for r in range(args.nb_rounds):
number_correct_responses += quiz_machine.models_successes(models, vq)
- with open(os.path.join(args.result_dir, prefix + "_responses.dat"), "w") as f:
- for n, r in enumerate(number_correct_responses):
- v = " ".join([str(n.item()) for n in r])
- f.write(f"{n}: {v}\n")
+ comments = []
+ for r in number_correct_responses:
+ comments.append("nb_correct " + " ".join([str(n.item()) for n in r]))
vq = quiz_machine.problem.reconfigure(vq, ("A", "f_A", "B", "f_B"))
- quiz_machine.problem.save_quizzes_as_image(args.result_dir, prefix, vq)
+ filename = f"culture_c_quiz_{n_epoch:04d}.png"
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, filename, vq, comments=comments
+ )
######################################################################
model.main_test_accuracy = 0.0
model.id = k
- quiz_machine.create_w_quizzes(
- model=model,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
+ model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+ args.nb_train_samples
)
+ model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
models.append(model)
######################################################################
+science_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+ args.nb_test_samples, science=True
+)
+
+######################################################################
+
current_epoch = 0
if args.resume:
c_quizzes = quiz_machine.problem.reconfigure(
c_quizzes, ("A", "f_A", "B", "f_B")
)
+
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
f"non_validated_{n_epoch:04d}_{model.id:02d}.png",
from_w = torch.arange(
quizzes.size(0), device=quizzes.device
) < w_quizzes.size(0)
- i = torch.randperm(quizzes.size(0), device=quizzes.device)
-
- return quizzes[i], from_w[i]
else:
- return w_quizzes, torch.full(
- (w_quizzes.size(0),), True, device=w_quizzes.device
- )
+ quizzes = w_quizzes.clone()
+ from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+
+ self.randomize_configuations_inplace(quizzes, structs=self.train_struct)
+
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+
+ return quizzes[i], from_w[i]
######################################################################
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
- deterministic_synthesis=False,
progress_bar_desc="accuracy",
device=self.device,
)
result = input.new(input.size())
correct = input.new(input.size(0))
predicted_parts = input.new(input.size(0), 4)
+
nb = 0
+
for struct, mask in [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1)),
(("f_A", "A", "f_B", "B"), (0, 0, 0, 1)),
quizzes[r == c], struct=structs[c]
)
- def create_w_quizzes(self, model, nb_train_samples, nb_test_samples):
- model.train_w_quizzes = self.problem.generate_w_quizzes(nb_train_samples)
- model.test_w_quizzes = self.problem.generate_w_quizzes(nb_test_samples)
-
- self.randomize_configuations_inplace(
- model.train_w_quizzes, structs=self.train_struct
- )
-
- self.randomize_configuations_inplace(
- model.test_w_quizzes, structs=self.train_struct
- )
-
######################################################################
def renew_train_w_quizzes(self, model):
if hasattr(model, "hard_w_quizzes"):
- self.logger(
- f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
- )
-
if model.hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+ nb_to_generate = 0
model.train_w_quizzes[...] = model.hard_w_quizzes[
torch.randperm(hard_w_quizzes.size(0))[
model.train_w_quizzes.size(0)
]
]
else:
+ nb_to_generate = model.train_w_quizzes.size(
+ 0
+ ) - model.hard_w_quizzes.size(0)
model.train_w_quizzes[...] = torch.cat(
[
model.hard_w_quizzes,
- self.problem.generate_w_quizzes(
- model.train_w_quizzes.size(0) - model.hard_w_quizzes.size(0)
- ),
+ self.problem.generate_w_quizzes(nb_to_generate),
],
dim=0,
)
else:
+ nb_to_generate = 0
model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
model.train_w_quizzes.size(0)
)
+ self.logger(
+ f"re-using {model.hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+ )
+
self.randomize_configuations_inplace(
model.train_w_quizzes, structs=self.train_struct
)
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- deterministic_synthesis=False,
device=self.device,
)
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba[:, model.id],
- deterministic_synthesis=False,
device=self.device,
)
temperature_hot=1.0,
temperature_cold=1.0,
):
- c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B")).to(
- self.device
- )
+ c_quizzes = self.problem.create_empty_quizzes(nb, ("f_B", "f_A", "A", "B"))
+ c_quizzes = c_quizzes.to(self.device)
seq_logproba = torch.zeros(nb, device=self.device)
),
seq_logproba=seq_logproba,
logit_transformer=lt_noisy,
- deterministic_synthesis=False,
device=self.device,
)
),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
- deterministic_synthesis=False,
device=self.device,
)
),
seq_logproba=seq_logproba,
logit_transformer=lt_clean,
- deterministic_synthesis=False,
device=self.device,
)