        return prompts.flatten(1), answers.flatten(1)

-    def save_quizzes(
+    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
        for t in self.all_tasks:
            print(t.__name__)
            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quizzes(
+            self.save_quiz_illustrations(
                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
            )
    ]:
        print(t.__name__)
        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow)
+        grids.save_quiz_illustrations(
+            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        )

    exit(0)
    predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
    predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)

-    grids.save_quizzes(
+    grids.save_quiz_illustrations(
        "/tmp",
        "test",
        prompts[:nb],
    q = new_c_quizzes[:72]

    if q.size(0) > 0:
-        quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q)
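+        # Illustrate at most the first 72 of the newly created c-quizzes.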
+        quiz_machine.save_quiz_illustrations(
+            args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q
+        )
######################################################################
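+# Number of models successfully restored from a checkpoint in the loop below.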
+nb_loaded_models = 0
+
models = []
for k in range(args.nb_gpts):
    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)

+ filename = f"gpt_{model.id:03d}.pth"
+
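+    # Resume this model from its checkpoint in result_dir, if one exists.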
+    try:
+        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
+        log_string(f"model {model.id} successfully loaded from checkpoint.")
+        nb_loaded_models += 1
+
+    except FileNotFoundError:
+        log_string(f"starting model {model.id} from scratch.")
+
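+    # Any other failure (e.g. a corrupt or incompatible checkpoint) is fatal.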
+    except Exception:
+        log_string(f"error when loading {filename}.")
+        exit(1)
+
    models.append(model)

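+# Checkpoint restoration must be all-or-nothing: a partial restore would mix
+# trained and freshly initialized models.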
+assert nb_loaded_models == 0 or nb_loaded_models == len(models)
nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
    return l[:, 0] < math.log(0.5)
-######################################################################
-
-nb_loaded_models = 0
-
-for model in models:
-    filename = f"gpt_{model.id:03d}.pth"
-
-    try:
-        model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
-        log_string(f"model {model.id} successfully loaded from checkpoint.")
-        nb_loaded_models += 1
-
-    except FileNotFoundError:
-        log_string(f"starting model {model.id} from scratch.")
-
-    except:
-        log_string(f"error when loading {filename}.")
-        exit(1)
-
-assert nb_loaded_models == 0 or nb_loaded_models == len(models)
-
######################################################################
for n_epoch in range(args.nb_epochs):
        self.train_c_quizzes = []
        self.test_c_quizzes = []

-    def save_quizzes(
+    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
        predicted_prompts *= 2
        predicted_answers *= 2

-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
            result_dir,
            filename_prefix,
            quizzes[:, 1 : 1 + self.prompt_len],
        ##############################

-        self.save_quizzes(
+        self.save_quiz_illustrations(
            result_dir,
            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
            quizzes=test_result[:72],
        return prompts, answers

-    def save_quizzes(
+    def save_quiz_illustrations(
        self,
        result_dir,
        filename_prefix,
    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1

-    sky.save_quizzes(
+    sky.save_quiz_illustrations(
        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
    )