desired_average_logits = None
for n_epoch in range(args.nb_epochs):
- log_string(f"--- epoch {n_epoch+1} ----------------------------------------")
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
a = [(model.id, float(model.main_test_accuracy)) for model in models]
a.sort(key=lambda p: p[0])