    return l[:, 0] < math.log(0.5)
+######################################################################
+
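+# Checkpoint recovery: reload each model's weights from its per-model file
+# in args.result_dir when present, and count how many were found.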
+nb_loaded_models = 0
+
+for model in models:
+    filename = f"gpt_{model.id:03d}.pth"
+
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
+    except Exception as e:
+        log_string(f"error when loading {filename}: {e}")
+        exit(1)
+
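+# Resuming only makes sense if every model had a checkpoint, or none did;
+# a partial reload would mix trained models with freshly initialized ones.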
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+
######################################################################
for n_epoch in range(args.nb_epochs):
    for t in threads:
        t.join()
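+    # Save a checkpoint for each model updated this epoch, matching the
+    # per-model files reloaded at startup.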
+    for model in weakest_models:
+        filename = f"gpt_{model.id:03d}.pth"
+        torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+
##################################################
# Replace a fraction of the w_quizzes with fresh ones