parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--prompt_noise", type=float, default=0.05)
+parser.add_argument("--prompt_noise_proba", type=float, default=0.05)
-parser.add_argument("--nb_hints", type=int, default=25)
+parser.add_argument("--hint_proba", type=float, default=0.01)
+
+# parser.add_argument("--nb_hints", type=int, default=25)
parser.add_argument("--nb_runs", type=int, default=1)
def add_hints(imt_set):
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
- h = torch.rand(masks.size(), device=masks.device) - masks
- t = h.sort(dim=1).values[:, args.nb_hints, None]
- mask_hints = (h < t).long()
+ # h = torch.rand(masks.size(), device=masks.device) - masks
+ # t = h.sort(dim=1).values[:, args.nb_hints, None]
+ # mask_hints = (h < t).long()
+ mask_hints = (
+ torch.rand(input.size(), device=input.device) < args.hint_proba
+ ).long() * masks
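+ # On average a fraction hint_proba of the masked positions become hints; they are removed from the mask and filled with their target values below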
masks = (1 - mask_hints) * masks
input = (1 - mask_hints) * input + mask_hints * targets
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
# Make pixels from the available input (mask=0) noise with probability
-# args.prompt_noise
+# args.prompt_noise_proba
def add_noise(imt_set):
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
noise = quiz_machine.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
- torch.rand(input.size(), device=input.device) < args.prompt_noise
+ torch.rand(input.size(), device=input.device) < args.prompt_noise_proba
).long()
input = (1 - change) * input + change * noise
return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
record = []
- src = imt_set.split(args.train_batch_size)
+ src = imt_set.split(args.eval_batch_size)
if desc is not None:
src = tqdm.tqdm(
src,
dynamic_ncols=True,
desc=desc,
- total=imt_set.size(0) // args.train_batch_size,
+ total=imt_set.size(0) // args.eval_batch_size,
)
for imt in src:
sub_changed = all_changed[all_changed].clone()
src = zip(
- sub_input.split(args.train_batch_size),
- sub_masks.split(args.train_batch_size),
- sub_changed.split(args.train_batch_size),
+ sub_input.split(args.eval_batch_size),
+ sub_masks.split(args.eval_batch_size),
+ sub_changed.split(args.eval_batch_size),
)
for input, masks, changed in src:
label = "train"
model.train().to(local_device)
optimizer_to(model.optimizer, local_device)
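+ # The batch size for the loop below depends on the mode (training vs evaluation)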
+ batch_size = args.train_batch_size
else:
label = "test"
model.eval().to(local_device)
+ batch_size = args.eval_batch_size
nb_samples, acc_loss = 0, 0.0
for imt in tqdm.tqdm(
- imt_set.split(args.train_batch_size),
+ imt_set.split(batch_size),
dynamic_ncols=True,
desc=label,
- total=quizzes.size(0) // args.train_batch_size,
+ total=quizzes.size(0) // batch_size,
):
input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
if train and nb_samples % args.batch_size == 0:
generator_id = model.id
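+ # Generate a pool of candidate c_quizzes with the current model, ten eval-sized batches at once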
c_quizzes = ae_generate(
- model=model, nb=args.train_batch_size * 10, local_device=local_device
+ model=model, nb=args.eval_batch_size * 10, local_device=local_device
)
# Select the ones that are solved properly by some models and