if shape == "fwd_3_bck_123":
forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T >= S + 1)).long()
+ backward_mask = ((T % (S + 1) != 0) & (T >= 1 * (S + 1))).long()
elif shape == "fwd_012_bck_0":
forward_mask = ((T % (S + 1) != 0) & (T < 3 * (S + 1))).long()
- backward_mask = ((T % (S + 1) != 0) & (T < S + 1)).long()
+ backward_mask = ((T % (S + 1) != 0) & (T < 1 * (S + 1))).long()
elif shape == "fwd_3_bck_3":
forward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
backward_mask = ((T % (S + 1) != 0) & (T >= 3 * (S + 1))).long()
S = self.height * self.width
Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1]
f_Bs = answers[:, 1:]
- print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}")
return (Bs == f_Bs).long().min(dim=-1).values > 0
def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
nb, nrow = 8, 2
# nb, nrow = 8, 2
- for t in grids.all_tasks:
- # for t in [grids.task_compute]:
+ # for t in grids.all_tasks:
+ for t in [grids.task_convex]:
print(t.__name__)
prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
# prompts[...] = torch.randint(grids.nb_token_values(), prompts.size())
######################################################################
+current_epoch = 0
+
if args.resume:
try:
for model in models:
log_string(f"cannot find {filename}")
pass
+ try:
+ filename = "state.pth"
+ state = torch.load(os.path.join(args.result_dir, filename))
+ log_string(f"successfully loaded {filename}")
+ current_epoch = state["current_epoch"]
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
+
except:
log_string(f"error when loading {filename}.")
exit(1)
######################################################################
-for n_epoch in range(args.nb_epochs):
+for n_epoch in range(current_epoch, args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
)
log_string(f"wrote {filename}")
+ state = {"current_epoch": n_epoch}
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
+
# Renew the training samples
for model in weakest_models:
###############################################################
- def solution_nb_correct(
- self, models_for_validation, c_quizzes, bidirectional_validation=True
- ):
+ def solution_nb_correct(self, models_for_validation, c_quizzes):
seq_logproba = torch.zeros(
c_quizzes.size(0),
max([m.id for m in models_for_validation]) + 1,
)
nb_correct = 0
+ correct_models = torch.empty(
+ c_quizzes.size(0),
+ max([m.id for m in models_for_validation]) + 1,
+ device=self.device,
+ dtype=torch.int64,
+ )
seq_logproba[...] = 0.0
device=self.device,
)
- correct = (c_quizzes == result).long().min(dim=-1).values
+ correct_models[:, model.id] = (
+ (c_quizzes == result).long().min(dim=-1).values
+ )
# -------------------------------
device=self.device,
)
- flipped_correct = (c_quizzes == result).long().min(dim=-1).values
+ correct_models[:, model.id] *= (
+ (c_quizzes == result).long().min(dim=-1).values
+ )
# -------------------------------
- nb_correct += correct * flipped_correct
+ i = correct_models.sum(dim=1) == correct_models.size(1) - 1
+ c = (correct_models[i] == 0).long().sum(dim=0)
+ self.logger(f"nb_failures_on_validated {tuple(x.item() for x in c)}")
- return nb_correct.to("cpu")
+ return correct_models.sum(dim=1).to("cpu")
###############################################################