######################################################################
+def vanilla_attention(q, k, v, score_mod=None):
+    """Scaled dot-product attention written with explicit einsums.
+
+    Meant to be callable like flex_attention(q, k, v, score_mod=...).
+
+    Args:
+        q: queries, indexed "nhtd" by the einsum below.
+        k: keys, indexed "nhsd".
+        v: values, indexed "nhsd".
+        score_mod: optional callable (score, b, h, q_idx, kv_idx) -> score
+            applied to the scaled scores before the softmax.
+            NOTE(review): it is called once on the whole "nhts" score
+            tensor with broadcastable index tensors, so it is assumed to
+            be written with element-wise tensor operations -- confirm for
+            any score_mod other than the identity.
+
+    Returns:
+        The attended values, indexed "nhtd". The output projection is
+        left to the caller, which owns the projection weights.
+    """
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+
+    if score_mod is not None:
+        nb, nh, nt, ns = a.shape
+        dev = a.device
+        a = score_mod(
+            a,
+            torch.arange(nb, device=dev)[:, None, None, None],
+            torch.arange(nh, device=dev)[None, :, None, None],
+            torch.arange(nt, device=dev)[None, None, :, None],
+            torch.arange(ns, device=dev)[None, None, None, :],
+        )
+
+    a = a.softmax(dim=3)
+
+    return torch.einsum("nhts,nhsd->nhtd", a, v)
+
+
+# Rebind the name to the torch.compile-wrapped function so that callers
+# transparently get the compiled version.
+vanilla_attention = torch.compile(vanilla_attention)
+
+
class MHAttention(nn.Module):
def __init__(
self,
x_kv = x_q
q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
- k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
- v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+ k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+ v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+
+ def noop(score, b, h, q_idx, kv_idx):
+ return score
- y = flex_attention(q, k, v)
+ y = vanilla_attention(q, k, v, score_mod=noop)
+ # y = flex_attention(q, k, v, score_mod=noop)
y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
import ffutils
-import attae
-
import mygpt
import sky, grids, quiz_machine
######################################################################
-def model_ae_proba_solutions(model, input, log_proba=False):
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
record = []
for x_0 in input.split(args.batch_size):
loss_per_token = F.cross_entropy(
logits.transpose(1, 2), x_0, reduction="none"
)
- loss += (loss_per_token * mask_generate).sum(dim=1)
+ if reduce:
+ loss += (loss_per_token * mask_generate).sum(dim=1)
+ else:
+ loss += loss_per_token * mask_generate
+
record.append(loss)
loss = torch.cat(record, dim=0)
- if log_proba:
+ if log_probas:
return -loss
else:
return (-loss).exp()
mask_generate = quiz_machine.make_quiz_mask(
quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
+
logits = logits_hat_x_0_from_random_iteration(
model, x_0, mask_generate, prompt_noise=args.prompt_noise
)
######################################################################
+# import attae
+
models = []
for i in range(args.nb_models):
- # model = MyAttentionAE(
- model = attae.AttentionAE(
+ model = MyAttentionAE(
+ # model = attae.AttentionAE(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
######################################################################
-def quiz_validation(models, c_quizzes, local_device):
+def quiz_validation_1(models, c_quizzes, local_device):
nb_have_to_be_correct = args.nb_models // 2
- nb_have_to_be_wrong = args.nb_models // 5
+ nb_have_to_be_wrong = 1
- nb_runs = 3
+ nb_runs = 1
nb_mistakes_to_be_wrong = 5
record_wrong = []
return to_keep, wrong
+def quiz_validation_2(models, c_quizzes, local_device):
+    """Flag the c_quizzes that enough models solve and enough models fail.
+
+    Every model scores the quizzes nb_runs times with
+    model_ae_proba_solutions(..., log_probas=True, reduce=False); the
+    element-wise minimum over the runs is what gets thresholded.
+
+    Args:
+        models: sequence of models; models[i].id must equal i.
+        c_quizzes: the quizzes to validate.
+        local_device: device on which a private copy of each model is run.
+
+    Returns:
+        to_keep: boolean tensor, one entry per quiz, true when at least
+            nb_have_to_be_correct models are "correct" and at least
+            nb_have_to_be_wrong models are "wrong" on it.
+        wrong: boolean tensor with one column per model, telling which
+            model is "wrong" on which quiz.
+    """
+    nb_have_to_be_correct = 3
+    nb_have_to_be_wrong = 1
+    nb_runs = 3
+
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
+
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        # Work on a private copy so the caller's model is left untouched
+        model = copy.deepcopy(model).to(local_device).eval()
+
+        # Element-wise worst log-probability over the runs
+        log_probas_min = None
+        for _ in range(nb_runs):
+            log_probas = model_ae_proba_solutions(
+                model, c_quizzes, log_probas=True, reduce=False
+            )
+            log_probas_min = (
+                log_probas
+                if log_probas_min is None
+                else log_probas.minimum(log_probas_min)
+            )
+
+        probas_min = log_probas_min.exp()
+
+        # Correct: no entry at or below 0.75
+        correct = (probas_min <= 0.75).long().sum(dim=1) == 0
+        # Wrong: at least three entries at or below 0.1
+        wrong = (probas_min <= 0.1).long().sum(dim=1) >= 3
+
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
+
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+
+    wrong = torch.cat(record_wrong, dim=1)
+
+    return to_keep, wrong
+
+
+def quiz_validation(models, c_quizzes, local_device):
+    """Flag the c_quizzes that enough models solve and enough models fail.
+
+    Every model scores the quizzes nb_runs times with
+    model_ae_proba_solutions(..., log_probas=True, reduce=False) and the
+    log-probabilities are summed over the runs before thresholding.
+
+    Returns a pair (to_keep, wrong): to_keep is a boolean tensor with one
+    entry per quiz, wrong a boolean tensor with one column per model.
+    """
+    nb_runs = 3
+    min_nb_correct, min_nb_wrong = 3, 1
+
+    wrong_columns = []
+    total_correct, total_wrong = 0, 0
+
+    for expected_id, model in enumerate(models):
+        assert expected_id == model.id  # a bit of paranoia
+        local_model = copy.deepcopy(model).to(local_device).eval()
+
+        # Log-probabilities accumulated over the runs
+        summed_log_probas = sum(
+            model_ae_proba_solutions(
+                local_model, c_quizzes, log_probas=True, reduce=False
+            )
+            for _ in range(nb_runs)
+        )
+
+        # Exponential of the sum, i.e. the product over the runs
+        probas = summed_log_probas.exp()
+
+        nb_below_correct = (probas <= 0.75).long().sum(dim=1)
+        correct = nb_below_correct == 0
+
+        nb_below_wrong = (probas <= 0.125).long().sum(dim=1)
+        # Whole-quiz probability, averaged over the runs in log space
+        mean_quiz_proba = summed_log_probas.sum(dim=1).div(nb_runs).exp()
+        wrong = (nb_below_wrong >= 5) & (mean_quiz_proba <= 0.5)
+
+        wrong_columns.append(wrong[:, None])
+        total_correct = total_correct + correct.long()
+        total_wrong = total_wrong + wrong.long()
+
+    to_keep = (total_correct >= min_nb_correct) & (total_wrong >= min_nb_wrong)
+
+    return to_keep, torch.cat(wrong_columns, dim=1)
+
+
def generate_ae_c_quizzes(models, nb, local_device=main_device):
# To be thread-safe we must make copies
start_time = time.perf_counter()
- for gpu in gpus:
- t = threading.Thread(
- target=thread_generate_ae_c_quizzes,
- daemon=True,
- args=(models, nb_c_quizzes_to_generate, records, gpu),
- )
+ if len(gpus) > 1:
+ for gpu in gpus:
+ t = threading.Thread(
+ target=thread_generate_ae_c_quizzes,
+ daemon=True,
+ args=(models, nb_c_quizzes_to_generate, records, gpu),
+ )
- # To get a different sequence between threads
- log_string(f"dummy {torch.rand(1)}")
- threads.append(t)
- t.start()
+ # To get a different sequence between threads
+ log_string(f"dummy {torch.rand(1)}")
+ threads.append(t)
+ t.start()
- for t in threads:
- t.join()
+ for t in threads:
+ t.join()
+
+ else:
+     # Single GPU: run in the calling thread. generate_ae_c_quizzes takes
+     # (models, nb, local_device), so `records` must not be passed to it;
+     # its result is appended here instead.
+     records.append(
+         generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0])
+     )
time_c_quizzes = int(time.perf_counter() - start_time)
start_time = time.perf_counter()
- for gpu, model in zip(gpus, weakest_models):
+ if len(gpus) > 1:
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+ if c_quizzes is None:
+ c_quizzes_for_this_model = None
+ else:
+ c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
+
+ t = threading.Thread(
+ target=one_ae_epoch,
+ daemon=True,
+ args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
+ )
+
+ threads.append(t)
+
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ else:
+ model = weakest_models[0]
log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
if c_quizzes is None:
c_quizzes_for_this_model = None
else:
c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
- t = threading.Thread(
- target=one_ae_epoch,
- daemon=True,
- args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
- )
-
- threads.append(t)
-
- t.start()
-
- for t in threads:
- t.join()
+ one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpus[0])
time_train += int(time.perf_counter() - start_time)