def one_epoch(model, quiz_machine, local_device=main_device):
model.to(local_device).train()
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
nb_train_samples, acc_train_loss = 0, 0.0
hard_w_quizzes = []
input = input.to(local_device)
if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
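+ # the optimizer now lives on the model, so its state persists across epochs and checkpoints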
+ model.optimizer.zero_grad()
targets = input
loss.backward()
if nb_train_samples % args.batch_size == 0:
- optimizer.step()
+ model.optimizer.step()
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
(("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
# (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
# (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
]
recorder=recorder,
)
- ##
-
- probas = 0
- for a in range(args.nb_averaging_rounds):
- # This is nb_quizzes x nb_models
-
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
-
- probas += seq_logproba.exp()
- probas /= args.nb_averaging_rounds
+ # This is nb_quizzes x nb_models
+ seq_logproba = quiz_machine.models_logprobas(
+ models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ ) + quiz_machine.models_logprobas(
+ models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+ probas = seq_logproba.exp()
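+ # a single scoring pass replaces the former nb_averaging_rounds loop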
comments = []
nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
- to_recycle = None
-
while nb_validated_per_model.sum() < nb_to_validate:
# We use the model that has generated the fewest quizzes to
# balance the number of quizzes per model overall
nb_to_generate_per_iteration,
model_for_generation=model,
procedure=c_quizzes_procedure,
- to_recycle=to_recycle,
)
# We discard the trivial ones, according to a criterion
# specific to the world quizzes (e.g. B=f(B))
- rejected = []
-
to_keep = quiz_machine.problem.trivial(c_quizzes) == False
- if not to_keep.all():
- rejected.append(c_quizzes[to_keep == False])
-
c_quizzes = c_quizzes[to_keep]
- probas = 0
-
- for a in range(args.nb_averaging_rounds):
- # This is nb_quizzes x nb_models
-
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- probas += seq_logproba.exp()
- probas /= args.nb_averaging_rounds
+ # This is nb_quizzes x nb_models
+ seq_logproba = quiz_machine.models_logprobas(
+ models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ ) + quiz_machine.models_logprobas(
+ models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+ probas = seq_logproba.exp()
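+ # probas[q, m] = exp of the two summed log-probas: model m's joint probability of the answer under both orderings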
nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
& (nb_fail <= args.max_fail_to_validate)
)
- to_recycle = c_quizzes[to_keep == False]
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
######################################################################
-
models = []
for k in range(args.nb_gpts):
dropout=args.dropout,
).to(main_device)
- model.main_test_accuracy = 0.0
model.id = k
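+ # one optimizer per model, created once and reused by one_epoch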
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ model.main_test_accuracy = 0.0
+
model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
args.nb_train_samples
)
try:
d = torch.load(os.path.join(args.result_dir, filename))
- model.load_state_dict(d[0])
- model.main_test_accuracy = d[1]
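+ # checkpoints switch from a (state_dict, accuracy) tuple to a dict that also carries the optimizer state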
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.main_test_accuracy = d["main_test_accuracy"]
log_string(f"successfully loaded {filename}")
except FileNotFoundError:
log_string(f"cannot find {filename}")
for model in weakest_models:
filename = f"gpt_{model.id:03d}.pth"
torch.save(
- (model.state_dict(), model.main_test_accuracy),
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "main_test_accuracy": model.main_test_accuracy,
+ },
os.path.join(args.result_dir, filename),
)
log_string(f"wrote {filename}")
self.answer_len = None
self.prompt_noise = prompt_noise
+ # struct, mask_generate, mask_noise
self.understood_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)),
(("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)),
quizzes, from_w = quizzes[i], from_w[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, m, _ in self.understood_structures]
+ quizzes, structs=[s for s, _, _ in self.understood_structures]
)
if self.prompt_noise > 0.0:
- for struct, mask, noise_mask in self.understood_structures:
+ for struct, _, mask_noise in self.understood_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
quizzes[i] = self.problem.inject_noise(
- quizzes[i], self.prompt_noise, struct=struct, mask=noise_mask
+ quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
)
return quizzes, from_w
nb = 0
# We consider all the configurations that we train for
- for struct, mask, _ in self.understood_structures:
+ for struct, mask_generate, _ in self.understood_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i] = self.predict(
- model=model, quizzes=input[i], struct=struct, mask=mask
+ model=model, quizzes=input[i], struct=struct, mask=mask_generate
)
- predicted_parts[i] = torch.tensor(mask, device=self.device)[None, :]
+ predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
+ None, :
+ ]
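+ # predicted_parts flags, per quiz, which of the four grids were generated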
solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
models_for_validation,
c_quizzes,
struct,
- mask,
- noise_mask=None,
+ mask_value,
+ mask_noise=None,
device=None,
):
if device is None:
device=device,
)
- if self.prompt_noise > 0.0 and noise_mask is not None:
- c_quizzes = self.problem.inject_noise(
- c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask
- )
+ # if self.prompt_noise > 0.0 and mask_noise is not None:
+ # c_quizzes = self.problem.inject_noise(
+ # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
+ # )
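+ # noise injection is disabled here: c_quizzes are scored on clean prompts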
for model in models_for_validation:
with torch.autograd.no_grad():
seq_logproba.split(self.batch_size),
):
input = input.to(device)
- ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
+ ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value)
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
######################################################################
- def generate_c_quizzes(
- self, nb, model_for_generation, procedure, to_recycle=None, recorder=None
- ):
+ def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
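+ # the to_recycle mechanism is gone: discarded quizzes are regenerated rather than reused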
seq_logproba = torch.zeros(nb, device=self.device)
c_quizzes = None
self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
)
- if to_recycle is not None and to_recycle.size(0) > 0:
- to_recycle = self.problem.reconfigure(to_recycle, s)
- c_quizzes[: to_recycle.size(0)] = to_recycle
-
- to_recycle = None
-
c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
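+ # everything is returned in the canonical ("A", "f_A", "B", "f_B") configuration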
return c_quizzes.to("cpu")