parser.add_argument("--nb_epochs", type=int, default=10000)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--physical_batch_size", type=int, default=None)
-parser.add_argument("--inference_batch_size", type=int, default=None)
+parser.add_argument("--inference_batch_size", type=int, default=25)
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=40000)
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=1000)
parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
parser.add_argument("--schedule_free", action="store_true", default=False)
# ----------------------------------
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
######################################################################
-default_args = {
- "model": "37M",
- "batch_size": 25,
- "inference_batch_size": 25,
- "nb_train_samples": 40000,
- "nb_test_samples": 1000,
-}
-
-for k, v in default_args.items():
-    if getattr(args, k) is None:
-        setattr(args, k, v)
-
-######################################################################
-
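# The deleted loop above is now redundant: the same values are supplied
# directly as argparse defaults. A quick sanity check (illustrative, not in
# the script):
#
# args = parser.parse_args([])
# assert (args.model, args.batch_size, args.inference_batch_size) == ("37M", 25, 25)
# assert (args.nb_train_samples, args.nb_test_samples) == (40000, 1000)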
default_model_args = {
"17K": {
"dim_model": 32,
######################################################################
if args.resume:
-    assert os.path.isdir(args.result_dir)
-
+    if not os.path.isdir(args.result_dir):
+        print(f"Trying to resume with a non-existent result dir {args.result_dir}.")
+        exit(1)
else:
    try:
        os.mkdir(args.result_dir)
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
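# Both assertions hold with the new defaults: 40000 % 25 == 0 and
# 1000 % 25 == 0. If --physical_batch_size splits a logical batch for
# gradient accumulation (an assumption, its use is not shown here), the
# analogous check would be:
#
# if args.physical_batch_size is not None:
#     assert args.batch_size % args.physical_batch_size == 0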
+######################################################################
+
problem = grids.Grids(
    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
    chunk_size=100,
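# With the defaults above, the cache covers the full training set: for a
# single GPU, max_nb_cached_chunks = 1 * 40000 // 100 = 400 chunks of
# chunk_size=100, i.e. 40000 samples.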
######################################################################
+# If we need to move an optimizer to a different device
+
def optimizer_to(optim, device):
    for param in optim.state.values():
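# Typical use (assumed, the call sites are outside this hunk): optimizer
# state tensors live on the device the parameters were on at step time, so
# after moving the model the optimizer state must follow:
#
# model.to(device)
# optimizer_to(optimizer, device)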