# background_gray=240
# dots = False
- # grid_gray = 200
+ # grid_gray = 192
# thickness = 0
- # background_gray = 240
+ # background_gray = 255
# dots = True
named_colors = [
######################################################################
def vocabulary_size(self):
- return self.nb_colors
+ warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+ return self.nb_colors + 4
def grid2img(self, x, scale=15, grids=True):
m = torch.logical_and(x >= 0, x < self.nb_colors).long()
:,
:,
:,
- scale // 2 - 2 : scale // 2 + 1,
+ scale // 2 - 1 : scale // 2 + 2,
:,
- scale // 2 - 2 : scale // 2 + 1,
+ scale // 2 - 1 : scale // 2 + 2,
]
- z[...] = (z == self.background_gray) * self.grid_gray + (
- z != self.background_gray
- ) * z
+ zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+ z[...] = zz * self.grid_gray + (zz == False) * z
for n in range(m.size(0)):
for i in range(m.size(1)):
comment_height=48,
nrow=4,
grids=True,
- margin=8,
+ margin=12,
delta=False,
):
quizzes = quizzes.to("cpu")
+ (1 - predicted_parts[:, :, None]) * white[None, None, :]
)
- img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
- img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
- img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
- img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
+ separation = 6
+
+ img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+ img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+ img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+ img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
img_A = self.add_frame(img_A, white[None, :], thickness=2)
img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
if delta:
- img_delta_A = self.add_frame(img_delta_A, colors[:, 0], thickness=8)
+ img_delta_A = self.add_frame(
+ img_delta_A, colors[:, 0], thickness=separation
+ )
img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
- img_delta_B = self.add_frame(img_delta_B, colors[:, 0], thickness=8)
+ img_delta_B = self.add_frame(
+ img_delta_B, colors[:, 0], thickness=separation
+ )
img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
img = torch.cat(
[img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
c_quizzes = c_quizzes[i]
w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
- w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
- w_quizzes.size(0), -1
- )
+
quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
nb_w_quizzes = w_quizzes.size(0)
nb_c_quizzes = c_quizzes.size(0)
return torch.cat(record)
-def predict_full(model, input, with_perturbations=False, local_device=main_device):
+def predict_full(
+ model, input, with_noise=False, with_hints=False, local_device=main_device
+):
input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
nb = input.size(0)
masks = input.new_zeros(input.size())
input = (1 - masks) * targets
imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
- if with_perturbations:
+ if with_hints:
imt_set = add_hints_imt(imt_set)
+
+ if with_noise:
imt_set = add_noise_imt(imt_set)
result = ae_predict(model, imt_set, local_device=local_device, desc=None)
problem.save_quizzes_as_image(
args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
)
- result = predict_full(model=model, input=quizzes, local_device=local_device)
+ result = predict_full(
+ model=model,
+ input=quizzes,
+ with_noise=True,
+ with_hints=True,
+ local_device=local_device,
+ )
problem.save_quizzes_as_image(
args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
)
result = predict_full(
model=model,
input=quizzes,
- with_perturbations=True,
+ with_noise=False,
+ with_hints=True,
local_device=local_device,
)
- nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long()
+ nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result)
+ nb_correct += (nb_mistakes == 0).long()
- result = predict_full(
- model=model,
- input=quizzes,
- with_perturbations=False,
- local_device=local_device,
- )
+ # result = predict_full(
+ # model=model,
+ # input=quizzes,
+ # with_noise=False,
+ # with_hints=False,
+ # local_device=local_device,
+ # )
- nb_wrong += (
- max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong
- ).long()
+ nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
nb_wrong >= args.nb_have_to_be_wrong
)
+ # print("\n\n", nb_correct, nb_wrong)
+
return to_keep, nb_correct, nb_wrong
def identity_quizzes(quizzes):
quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
- return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
+ return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
quizzes[:, 2] == quizzes[:, 3]
).min(dim=1).values