parser.add_argument("--batch_size", type=int, default=25)
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--train_batch_size", type=int, default=None)
-parser.add_argument("--inference_batch_size", type=int, default=25)
+parser.add_argument("--eval_batch_size", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=50000)
assert len(gpus) == 0
main_device = torch.device("cpu")
-if args.physical_batch_size is None:
- args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+ args.train_batch_size = args.batch_size
else:
- assert args.batch_size % args.physical_batch_size == 0
+ assert args.batch_size % args.train_batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
alien_quiz_machine = quiz_machine.QuizMachine(
problem=alien_problem,
- batch_size=args.inference_batch_size,
+ batch_size=args.eval_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=main_device,
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
- batch_size=args.inference_batch_size,
+ batch_size=args.eval_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=main_device,
record = []
- src = imt_set.split(args.physical_batch_size)
+ src = imt_set.split(args.train_batch_size)
if desc is not None:
src = tqdm.tqdm(
src,
dynamic_ncols=True,
desc=desc,
- total=imt_set.size(0) // args.physical_batch_size,
+ total=imt_set.size(0) // args.train_batch_size,
)
for imt in src:
all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
for it in range(args.diffusion_nb_iterations):
+ log_string(f"nb_changed {all_changed.long().sum().item()}")
+
if not all_changed.any():
break
sub_changed = all_changed[all_changed].clone()
src = zip(
- sub_input.split(args.physical_batch_size),
- sub_masks.split(args.physical_batch_size),
- sub_changed.split(args.physical_batch_size),
+ sub_input.split(args.train_batch_size),
+ sub_masks.split(args.train_batch_size),
+ sub_changed.split(args.train_batch_size),
)
for input, masks, changed in src:
nb_samples, acc_loss = 0, 0.0
for imt in tqdm.tqdm(
- imt_set.split(args.physical_batch_size),
+ imt_set.split(args.train_batch_size),
dynamic_ncols=True,
desc=label,
- total=quizzes.size(0) // args.physical_batch_size,
+ total=quizzes.size(0) // args.train_batch_size,
):
input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
if train and nb_samples % args.batch_size == 0:
generator_id = model.id
c_quizzes = ae_generate(
- model=model, nb=args.physical_batch_size * 10, local_device=local_device
+ model=model, nb=args.train_batch_size * 10, local_device=local_device
)
# Select the ones that are solved properly by some models and