class Grids(problem.Problem):
- # grid_gray=64
+ grid_gray = 64
+ thickness = 1
+ background_gray = 255
+
+ # grid_gray=240
# thickness=1
- # background_gray=255
+ # background_gray=240
- grid_gray = 255
- thickness = 0
- background_gray = grid_gray
+ # grid_gray = 255
+ # thickness = 0
+ # background_gray = 240
named_colors = [
("white", [background_gray, background_gray, background_gray]),
# ("white", [224, 224, 224]),
("red", [255, 0, 0]),
- ("green", [0, 192, 0]),
+ ("green", [0, 160, 0]),
("blue", [0, 0, 255]),
("yellow", [255, 224, 0]),
("cyan", [0, 255, 255]),
("violet", [224, 128, 255]),
- ("lightgreen", [192, 255, 192]),
+ ("lightgreen", [160, 255, 160]),
("brown", [165, 42, 42]),
("lightblue", [192, 192, 255]),
("gray", [128, 128, 128]),
parser.add_argument("--nb_test_samples", type=int, default=1000)
-parser.add_argument("--nb_c_quizzes", type=int, default=10000)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10)
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
# ----------------------------------
def log_string(s):
- """print the given string prefixed with a time stamps, and log it into log_file is not None"""
+ """Print the given string prefixed with a time stamp, and log it
+ into log_file if it is not None."""
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
######################################################################
-# If we need to move an optimizer to a different device
-
def optimizer_to(optim, device):
+ """Move the optimizer optim to the device"""
for param in optim.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
######################################################################
-# Make args.nb_hints holes in the mask and copy the corresponding cell
-# values from the target to the input
-
-
def add_hints_imt(imt_set):
+ """Set every component of the mask to zero with probability
+ args.proba_hint, and for each component set to zero, copy the
+ corresponding value from the target into the input
+
+ """
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
# h = torch.rand(masks.size(), device=masks.device) - masks
# t = h.sort(dim=1).values[:, args.nb_hints, None]
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-# Make pixels from the available input (mask=0) noise with probability
-# args.proba_prompt_noise
-
-
def add_noise_imt(imt_set):
+ """Replace every component of the available input (mask=0) by a
+ random value with probability args.proba_prompt_noise."""
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
noise = quiz_machine.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
models = []
for i in range(args.nb_models):
- model = attae.FunctionalAttentionAE(
- # model = attae.AttentionAE(
+ # model = attae.FunctionalAttentionAE(
+ model = attae.AttentionAE(
vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
######################################################################
-def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
+def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
nb_correct, nb_wrong = 0, 0
for model in models:
result = predict_full(
model=model,
input=quizzes,
- with_perturbations=True,
+ with_perturbations=with_perturbations,
local_device=local_device,
)
nb_mistakes = (result != quizzes).long().sum(dim=1)
######################################################################
+def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device):
+    """Count the quizzes of c_quizzes flagged as problematic.
+
+    c_quizzes is split into chunks of args.eval_batch_size, and each
+    chunk is evaluated in order with evaluate_quizzes, without
+    perturbations. A quiz is flagged when its nb_wrong entry, as
+    returned by evaluate_quizzes, is strictly positive. The scan
+    stops as soon as at least nb_to_remove quizzes have been flagged.
+
+    Returns the number of flagged quizzes as a Python int. It can
+    exceed nb_to_remove, since a chunk is always counted in full
+    before the stopping test.
+
+    NOTE(review): despite its name, this function does not remove
+    anything from c_quizzes yet, it only counts the candidates. The
+    count used to be discarded, it is now returned so that callers
+    can use it -- confirm the intended removal policy.
+
+    """
+    nb_removed = 0
+
+    # Named "batch" and not "input" to avoid shadowing the builtin
+    for batch in c_quizzes.split(args.eval_batch_size):
+        _, _, nb_wrong = evaluate_quizzes(
+            quizzes=batch,
+            models=models,
+            with_perturbations=False,
+            local_device=local_device,
+        )
+
+        to_remove = nb_wrong > 0
+        # .item() keeps the counter a plain Python int instead of
+        # silently turning it into a 0-dim tensor
+        nb_removed += to_remove.long().sum().item()
+
+        if nb_removed >= nb_to_remove:
+            break
+
+    return nb_removed
+
+
+######################################################################
+
+
def identity_quizzes(quizzes):
quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
- fraction_with_hints=1.0,
+ with_perturbations=True,
local_device=local_device,
)
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
- fraction_with_hints=0,
+ with_perturbations=False,
local_device=local_device,
)