parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--min_to_validate", type=int, default=None)
-
-parser.add_argument("--max_to_validate", type=int, default=None)
-
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
parser.add_argument("--proba_understands", type=float, default=0.99)
args = parser.parse_args()
-if args.min_to_validate is None:
- args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
- args.max_to_validate = args.nb_gpts - 1
-
if args.result_dir is None:
args.result_dir = f"results_culture"
def one_epoch(model, quiz_machine, local_device=main_device):
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
model.to(local_device).train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
nb_train_samples, acc_train_loss = 0, 0.0
for input in quiz_machine.batches(model, split="train"):