nb_threads=args.nb_threads,
tasks=args.grids_science_tasks,
)
- science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples)
+ science_w_quizzes = science_problem.generate_w_quizzes(100)
+
if not args.resume:
- problem.save_some_examples(args.result_dir, "science_")
+ science_problem.save_some_examples(args.result_dir, "science_")
else:
def model_transformer_hot(model):
# model.temperature = args.temperature_hot
- model.set_noise_injection(5.0, ("ffw", args.nb_blocks // 2))
+ model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
def model_transformer_cold(model):
current_epoch = 0
if args.resume:
- try:
- for model in models:
- filename = f"gpt_{model.id:03d}.pth"
-
- try:
- d = torch.load(os.path.join(args.result_dir, filename))
- model.load_state_dict(d[0])
- model.main_test_accuracy = d[1]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
+ for model in models:
+ filename = f"gpt_{model.id:03d}.pth"
try:
- filename = "c_quizzes.pth"
- quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+ d = torch.load(os.path.join(args.result_dir, filename))
+ model.load_state_dict(d[0])
+ model.main_test_accuracy = d[1]
log_string(f"successfully loaded {filename}")
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
- try:
- filename = "state.pth"
- state = torch.load(os.path.join(args.result_dir, filename))
- log_string(f"successfully loaded {filename}")
- current_epoch = state["current_epoch"]
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
+ try:
+ filename = "c_quizzes.pth"
+ quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+ log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
- except:
- log_string(f"error when loading {filename}.")
- exit(1)
+ try:
+ filename = "state.pth"
+ state = torch.load(os.path.join(args.result_dir, filename))
+ log_string(f"successfully loaded {filename}")
+ current_epoch = state["current_epoch"]
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
######################################################################
args.nb_new_c_quizzes_for_train = 100
args.nb_new_c_quizzes_for_test = 10
- def compute_valid_quizzes(token_logprobas):
- l = token_logprobas.sum(dim=-1).sort(dim=-1).values
- return torch.rand(l[:, 0].size(), device=l.device) < 0.5
-
-
######################################################################
for n_epoch in range(current_epoch, args.nb_epochs):