self.check_structure(quizzes, struct)
return struct
- def inject_noise(self, quizzes, noise, struct, mask):
+ def inject_noise(self, quizzes, noise, struct, quad):
assert self.check_structure(quizzes, struct=struct)
S = self.height * self.width
- mask = torch.tensor(mask, device=quizzes.device)
+ mask = torch.tensor(quad, device=quizzes.device)
mask = mask[None, :, None].expand(1, 4, S + 1).clone()
mask[:, :, 0] = 0
mask = mask.reshape(1, -1).expand_as(quizzes)
).values
def make_quiz_mask(
- self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+ self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
):
assert self.check_structure(quizzes, struct)
- ar_mask = quizzes.new_zeros(quizzes.size())
+ mask_ar = quizzes.new_zeros(quizzes.size())
S = self.height * self.width
- a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
- a[:, 0, :] = mask[0]
- a[:, 1, :] = mask[1]
- a[:, 2, :] = mask[2]
- a[:, 3, :] = mask[3]
+ a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+ a[:, 0, :] = quad[0]
+ a[:, 1, :] = quad[1]
+ a[:, 2, :] = quad[2]
+ a[:, 3, :] = quad[3]
- return ar_mask
+ return mask_ar
def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
S = self.height * self.width
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
+ torch.set_float32_matmul_precision("high")
######################################################################
######################################################################
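+# Converts a mask_ar tensor (values 0, 1 and 2) into per-token generation
+# ranks: every token with value 0 or 1 gets its own rank, while a run of
+# consecutive 2s shares a single rank. For instance the row
+# [0, 1, 1, 2, 2, 0, 2, 2] is mapped to the ranks [0, 1, 2, 3, 3, 4, 5, 5].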
+def mask_ar_to_ranks(mask_ar):
+ a = (mask_ar < 2).long()
+ a = a.cumsum(dim=1) - a
+ b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
+ a[:, 1:] += b
+ return a
+
+
def run_tests(model, quiz_machine, local_device=main_device):
with torch.autograd.no_grad():
model.to(local_device).eval()
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- full_input, full_mask_loss = quiz_machine.data_input(
+ full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
args.nb_test_samples, test_c_quiz_bags
)
src = zip(
- full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ full_input.split(args.batch_size),
+ full_mask_ar.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
)
- for input, mask_loss in tqdm.tqdm(
+ for input, mask_ar, mask_loss in tqdm.tqdm(
src,
dynamic_ncols=True,
desc="test",
total=full_input.size(0) // args.batch_size,
):
input = input.to(local_device)
+ mask_ar = mask_ar.to(local_device)
mask_loss = mask_loss.to(local_device)
targets = input
- output = model(mygpt.BracketedSequence(input)).x
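+ # mask_ar is converted into per-token ranks so that the attention
+ # layers only let a token attend to itself and to tokens of strictly
+ # lower rank.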
+ output = model(
+ mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+ ).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
- input, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
+ input, _, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
model.test_accuracy = quiz_machine.produce_results(
n_epoch=n_epoch,
nb_train_samples, acc_train_loss = 0, 0.0
- full_input, full_mask_loss = quiz_machine.data_input(
+ full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
args.nb_train_samples, train_c_quiz_bags
)
- src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
- for input, mask_loss in tqdm.tqdm(
+ src = zip(
+ full_input.split(args.batch_size),
+ full_mask_ar.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
+ )
+
+ for input, mask_ar, mask_loss in tqdm.tqdm(
src,
dynamic_ncols=True,
desc="training",
total=full_input.size(0) // args.batch_size,
):
input = input.to(local_device)
+ mask_ar = mask_ar.to(local_device)
mask_loss = mask_loss.to(local_device)
if nb_train_samples % args.batch_size == 0:
targets = input
- output = model(mygpt.BracketedSequence(input)).x
+ output = model(
+ mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+ ).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
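+# Each entry is one generation step: (structure, quad of the quadrants to
+# generate, model modifier, e.g. a change of the sampling temperature).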
c_quizzes_procedure = [
- # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
- # (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
- # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+ # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+ # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+ # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
]
######################################################################
procedure=c_quizzes_procedure,
)
+ log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
nb_generated += c_quizzes.size(0)
# We discard the trivial ones, according to a criterion
c_quizzes = c_quizzes[to_keep]
+ log_string(f"nb_non_trivial_quizzes {c_quizzes.size(0)}")
+
# Keep only the quizzes that the main model cannot solve
solved_c_quizzes = c_quizzes.clone()
- mask=(0, 0, 0, 1),
+ quad_ar=(0, 0, 0, 1),
)
+ log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
main_probas = model_proba_solutions(main_model, main_solution)
- log_string(f"main_probas {main_probas}")
+ # log_string(f"main_probas {main_probas}")
keep = main_probas < args.proba_not_understands
c_quizzes = c_quizzes[keep]
+ log_string(f"nb_not_understood_quizzes {c_quizzes.size(0)}")
+
# If there are some quizzes that the main model cannot solve,
# pick the most confident solution
)
probas = model_proba_solutions(model, solution)
- log_string(f"probas {probas}")
+ # log_string(f"probas {probas}")
keep = probas >= c_quizzes_proba
c_quizzes = solution[keep]
c_quizzes_proba[keep] = probas[keep]
keep = c_quizzes_proba >= args.proba_understands
- recorded.append(c_quizzes_proba[keep])
- nb_validated += keep.long().sum()
+ c_quizzes = c_quizzes[keep]
+
+ log_string(f"nb_kept {c_quizzes.size(0)} total nb_validated {nb_validated}")
+ recorded.append(c_quizzes.clone().to("cpu"))
+ nb_validated += c_quizzes.size(0)
duration = time.perf_counter() - start_time
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- compute_attzero=compute_causal_attzero,
dropout=args.dropout,
).to(main_device)
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- compute_attzero=compute_causal_attzero,
dropout=args.dropout,
).to(main_device)
model.load_state_dict(new_model.state_dict())
class BracketedSequence:
- def __init__(self, x, first=None, nb=None):
+ def __init__(self, x, first=None, nb=None, ranks=None):
self.x = x
self.first = 0 if first is None else first
self.nb = x.size(1) if nb is None else nb
+ self.ranks = ranks
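+ # ranks, when provided, gives the generation rank of every token; it is
+ # used by the attention layers to build their attention mask.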
def slice(self):
return self.x[:, self.first : self.first + self.nb]
else:
self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
+ return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
##############################
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
def forward(self, bs):
- return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
+ return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.ranks)
##############################
bs.slice() + self.pe[bs.first : bs.first + bs.nb]
)
- return BracketedSequence(self.cache_y, bs.first, bs.nb)
+ return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
##############################
dim_qk,
dim_v,
nb_heads=1,
- compute_attzero=None,
attention_dropout=0.0,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
- self.compute_attzero = compute_attzero
self.attention_dropout = attention_dropout
self.record_attention = False
"nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
) / math.sqrt(self.w_q.size(1))
- if self.compute_attzero is not None:
- if bs_q.first == 0:
- self.cache_attzero = self.compute_attzero(
- torch.arange(x_q.size(1), device=q.device)[:, None],
- torch.arange(x_kv.size(1), device=q.device)[None, :],
- )[None, None, :, :]
+ t = torch.arange(x_q.size(1), device=a.device)
+
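+ # A query at position t can attend to a key at position t' only if the
+ # key's rank is strictly lower than the query's, or if t == t'.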
+ if bs_q.ranks is not None:
a = a.masked_fill(
- self.cache_attzero[
- :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
- ],
+ (
+ bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+ <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+ )
+ & (
+ t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+ != t[None, None, None, : bs_kv.first + bs_kv.nb]
+ ),
float("-inf"),
)
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
- compute_attzero=compute_attzero,
attention_dropout=dropout,
)
super().__init__()
def forward(self, bs):
- return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.ranks)
class MyGPT(nn.Module):
dim_hidden,
nb_heads,
nb_blocks,
- compute_attzero=None,
dropout=0.0,
len_max=1e5,
):
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
- compute_attzero=compute_attzero,
attention_dropout=dropout,
),
),
######################################################################
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
+# mask_ar is a tensor of 0s, 1s and 2s, of same shape as input. Non-zero
+# values mark the tokens to generate, the others are kept unchanged. A
+# value of 1 starts a new generation step (the token gets its own rank),
+# while a run of consecutive 2s is generated in a single step, all its
+# tokens sharing the same rank.
def one_batch_masked_inplace_autoregression(
model,
input,
- ar_mask,
+ mask_ar,
acc_seq_logprobas,
deterministic_synthesis=False,
):
if input.size(0) == 0:
return
- to_generate = (ar_mask.sum(0) > 0).nonzero()
+ mask = (mask_ar > 0).long()
+ to_generate = (mask.sum(0) > 0).nonzero()
+
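+ # The sequence is generated step by step: a new step starts at every
+ # position where some row has mask_ar == 1, and all the positions of a
+ # step are sampled in a single forward pass.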
+ indices_1 = ((mask_ar == 1).long().sum(0) > 0).nonzero().flatten()
+ indices_1 = indices_1.tolist() + [mask.size(1)]
if to_generate.min() > 0:
model(
BracketedSequence(input, 0, to_generate.min())
) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = model(BracketedSequence(input, s, 1)).x
- logits = output[:, s]
+ s = to_generate.min()
+
+ for s, u in zip(indices_1[:-1], indices_1[1:]):
+ logits = model(BracketedSequence(input, s, u - s)).x[:, s:u]
if deterministic_synthesis:
- t_next = logits.argmax(-1)
+ t_next = logits.argmax(dim=2)
else:
dist = torch.distributions.categorical.Categorical(logits=logits)
t_next = dist.sample()
- all_n = torch.arange(t_next.size(0))
-
- acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
+ acc_seq_logprobas += (
+ mask[:, s:u]
+ * logits.log_softmax(dim=2).gather(dim=2, index=t_next[:, :, None])[:, :, 0]
+ ).sum(dim=1)
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+ input[:, s:u] = mask[:, s:u] * t_next + (1 - mask[:, s:u]) * input[:, s:u]
######################################################################
self.answer_len = None
self.prompt_noise = prompt_noise
- # struct, mask_generate, mask_noise, mask_loss
+ # struct, quad_ar, quad_noise, quad_loss
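+ # In quad_ar, 0 means the quadrant is given as input, 1 that its tokens
+ # are generated one at a time (one rank each), and 2 that they are all
+ # generated in a single step (shared rank).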
self.train_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
]
self,
model,
input,
- ar_mask,
+ mask_ar,
seq_logprobas,
progress_bar_desc=None,
):
- assert input.size() == ar_mask.size()
+ assert input.size() == mask_ar.size()
batches = zip(
input.split(self.batch_size),
- ar_mask.split(self.batch_size),
+ mask_ar.split(self.batch_size),
seq_logprobas.split(self.batch_size),
)
t = model.training
model.eval()
- for input, ar_mask, seq_logprobas in batches:
+ for input, mask_ar, seq_logprobas in batches:
one_batch_masked_inplace_autoregression(
model=model,
input=input,
- ar_mask=ar_mask,
+ mask_ar=mask_ar,
acc_seq_logprobas=seq_logprobas,
deterministic_synthesis=False,
)
quizzes, structs=[s for s, _, _, _ in self.train_structures]
)
+ quiz_mask_ar = quizzes.new_full(quizzes.size(), 1)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
- if self.prompt_noise > 0.0:
- for struct, _, mask_noise, mask_loss in self.train_structures:
- i = self.problem.indices_select(quizzes=quizzes, struct=struct)
- if i.any():
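+ # For every training structure present in the batch, optionally inject
+ # noise in the prompt and build the auto-regression and loss masks.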
+ for struct, quad_ar, quad_noise, quad_loss in self.train_structures:
+ i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+ if i.any():
+ if self.prompt_noise > 0.0:
quizzes[i] = self.problem.inject_noise(
- quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
- )
- quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], struct=struct, mask=mask_loss
+ quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
)
+ quiz_mask_ar[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], struct=struct, quad=quad_ar
+ )
+ quiz_mask_loss[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], struct=struct, quad=quad_loss
+ )
- return quizzes, quiz_mask_loss
+ return quizzes, quiz_mask_ar, quiz_mask_loss
######################################################################
- def make_quiz_mask(self, quizzes, struct, mask):
+ def make_quiz_mask(self, quizzes, struct, quad):
assert struct in [s for s, _, _, _ in self.train_structures]
- return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
+ return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
######################################################################
- def predict(self, model, quizzes, struct, mask):
+ def predict(self, model, quizzes, struct, quad_ar):
quizzes = quizzes.to(self.device)
- ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
- result = quizzes * (1 - ar_mask)
+ mask_ar = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad_ar)
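+ # Erase the tokens that have to be generated, keep the rest as prompt.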
+ result = quizzes * (mask_ar == 0).long()
seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
self.autoregression(
model=model,
input=result,
- ar_mask=ar_mask,
+ mask_ar=mask_ar,
seq_logprobas=seq_logprobas,
progress_bar_desc="autoregression",
)
nb = 0
# We consider all the configurations that we train for
- for struct, mask_generate, _, _ in self.test_structures:
+ for struct, quad_ar, _, _ in self.test_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i], _ = self.predict(
- model=model, quizzes=input[i], struct=struct, mask=mask_generate
+ model=model, quizzes=input[i], struct=struct, quad_ar=quad_ar
)
- predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
- None, :
- ]
+ predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :]
solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
self.autoregression(
model=model_for_generation,
input=c_quizzes,
- ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+ mask_ar=self.make_quiz_mask(c_quizzes, s, m),
seq_logprobas=seq_logprobas,
progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
)