######################################################################
+# Make args.nb_hints holes in the mask and copy the corresponding cell
+# values from the target to the input
-def add_hints_(imt_set):
- input, masks, targets = imt_set
+
+def add_hints(imt_set):
+ input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
h = torch.rand(masks.size(), device=masks.device) - masks
t = h.sort(dim=1).values[:, args.nb_hints, None]
mask_hints = (h < t).long()
- masks[...] = (1 - mask_hints) * masks
- input[...] = (1 - mask_hints) * input + mask_hints * targets
-
-
-def add_hints(masks, fraction_with_hints):
- if fraction_with_hints > 0:
- h = torch.rand(masks.size(), device=masks.device) - masks
- t = h.sort(dim=1).values[:, args.nb_hints, None]
- mask_hints = (h < t).long()
- return (1 - mask_hints) * masks
- else:
- return masks
+ masks = (1 - mask_hints) * masks
+ input = (1 - mask_hints) * input + mask_hints * targets
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
# IMT for input / masks / target
return torch.cat(record)
-def predict_full(model, input, local_device=main_device):
+def predict_full(model, input, with_hints=False, local_device=main_device):
input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
nb = input.size(0)
masks = input.new_zeros(input.size())
input = (1 - masks) * targets
imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+ if with_hints:
+ imt_set = add_hints(imt_set)
+
result = ae_predict(model, imt_set, local_device=local_device, desc=None)
result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
args.c_quiz_multiplier,
)
- q1, q2 = quizzes.to(local_device).chunk(2)
-
- imt_set = torch.cat(
- [
- batch_for_prediction_imt(q1),
- batch_for_generation_imt(q2),
- ]
- )
-
+ q_p, q_g = quizzes.to(local_device).chunk(2)
+ b_p = batch_for_prediction_imt(q_p)
+ i = torch.rand(b_p.size(0)) < 0.5
+ b_p[i] = add_hints(b_p[i])
+ b_g = batch_for_generation_imt(q_g)
+ imt_set = torch.cat([b_p, b_g])
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
if train:
result = predict_full(
model=model,
input=quizzes,
- fraction_with_hints=fraction_with_hints,
+ with_hints=True,
local_device=local_device,
)
nb_mistakes = (result != quizzes).long().sum(dim=1)