mask_ar = quizzes.new_zeros(quizzes.size())
S = self.height * self.width
- a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+ a = mask_ar.view(mask_ar.size(0), 4, S + 1)[:, :, 1:]
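+        # a is a view into mask_ar, so the per-quadrant assignments below write directly into mask_ar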
a[:, 0, :] = quad[0]
a[:, 1, :] = quad[1]
a[:, 2, :] = quad[2]
output = model(
mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
).x
+
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
######################################################################
+def model_proba_solutions(model, quizzes):
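+    # Log-probability of the tokens selected by the loss mask (the final grid)
+    # under the four possible orderings of the quiz, summed over the orderings
+    # and converted to a probability with exp()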
+ l = (
+ quiz_machine.models_logprobas(
+ model,
+ quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 2),
+ (0, 0, 1, 0),
+ (0, 0, 0, 1),
+ )
+ + quiz_machine.models_logprobas(
+ model,
+ quizzes,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 2),
+ (0, 0, 1, 0),
+ (0, 0, 0, 1),
+ )
+ + quiz_machine.models_logprobas(
+ model,
+ quizzes,
+ ("B", "f_B", "A", "f_A"),
+ (0, 0, 0, 2),
+ (0, 0, 1, 0),
+ (0, 0, 0, 1),
+ )
+ + quiz_machine.models_logprobas(
+ model,
+ quizzes,
+ ("f_B", "B", "f_A", "A"),
+ (0, 0, 0, 2),
+ (0, 0, 1, 0),
+ (0, 0, 0, 1),
+ )
+ )
+
+ return l.exp()
+
+
def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
# Save generated quizzes with the successive generation steps
# This is nb_quizzes x nb_models
- l = [
- quiz_machine.models_logprobas(
- model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- for model in models
- ]
+ l = [model_proba_solutions(model, c_quizzes) for model in models]
-    seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
-    probas = seq_logprobas.exp()
+    # model_proba_solutions already applies exp(), so these are probabilities
+    probas = torch.cat([x[:, None] for x in l], dim=1)
+    seq_logprobas = probas.log()
######################################################################
-def model_proba_solutions(model, quizzes):
- l = (
- quiz_machine.models_logprobas(
- model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- )
-
- return l.exp()
-
-
def create_c_quizzes(
main_model,
other_models,
return input
-######################################################################
-
-
-def save_generated_c_quizzes(model, filename, nb=64):
- while sum([x.size(0) for x in record]) < nb:
- model = models[torch.randint(len(models), (1,)).item()]
- c_quizzes = quiz_machine.generate_c_quizzes(
- 64,
- model_for_generation=model,
- procedure=c_quizzes_procedure,
- )
-
- p = quiz_machine.models_logprobas(
- model,
- c_quizzes,
- ("A", "f_A", "B", "f_B"),
- (1, 1, 1, 1),
- temperature=1,
- ).exp()
-
- p_hot = quiz_machine.models_logprobas(
- model,
- c_quizzes,
- ("A", "f_A", "B", "f_B"),
- (1, 1, 1, 1),
- temperature=args.temperature_hot,
- ).exp()
-
- to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
- record.append(c_quizzes[to_keep])
-
- print("NB_KEPT", sum([x.size(0) for x in record]))
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=c_quizzes,
- )
-
- log_string(f"wrote {filename}")
-
-
######################################################################
for n_epoch in range(current_epoch, args.nb_epochs):
return a
+# mask_ar = torch.tensor([[ 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1]])
+# print(mask_ar)
+# print(mask_ar_to_ranks(mask_ar))
+# exit(0)
+
+
class BracketedSequence:
def __init__(self, x, first=None, nb=None, ranks=None):
self.x = x
dim_qk,
dim_v,
nb_heads=1,
+ first_one=False,
attention_dropout=0.0,
):
super().__init__()
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
self.attention_dropout = attention_dropout
+ self.first_one = first_one
+
self.record_attention = False
self.w_q = randw(nb_heads, dim_qk, dim_in)
t = torch.arange(x_q.size(1), device=a.device)
- if bs_q.ranks is not None:
- a = a.masked_fill(
- (
- bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
- <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
- )
- & (
- t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
- != t[None, None, None, : bs_kv.first + bs_kv.nb]
- ),
- float("-inf"),
+ assert bs_q.ranks is not None
+
+ # rank_forward = (
+ # bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+ # >= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+ # )
+
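+        # first_one masks the diagonal too: each position attends only to strictly
+        # earlier tokens; otherwise standard causal masking (self included) applies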
+ if self.first_one:
+ rank_forward = (
+ t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+ <= t[None, None, None, : bs_kv.first + bs_kv.nb]
+ )
+ else:
+ rank_forward = (
+ t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+ < t[None, None, None, : bs_kv.first + bs_kv.nb]
)
+ a = a.masked_fill(rank_forward, float("-inf"))
+
a = a.softmax(dim=3)
if self.record_attention:
self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
- return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+ return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb, bs_q.ranks)
##############################
self.positional_encoding = AddPositionalEncoding(len_max)
- trunk_blocks = []
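+        # the trunk now starts with a strictly causal attention block (first_one=True)
+        # placed before the regular transformer blocks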
+ trunk_blocks = [
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ first_one=True,
+ attention_dropout=dropout,
+ )
+ ]
for b in range(nb_blocks):
trunk_blocks += [
for m in self.modules():
m.loss = 0
- bs = self.shifter(bs)
+ # bs = self.shifter(bs)
bs = self.embedding(bs)
bs = self.positional_encoding(bs)
bs = self.trunk(bs)
indices_1 = list(((mask_ar == 1).long().sum(0) > 0).nonzero()) + [mask.size(1)]
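+    # the AR ranks are computed once for the full sequence and reused for every slice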
+ ranks = mygpt.mask_ar_to_ranks(mask_ar)
+
if to_generate.min() > 0:
model(
- BracketedSequence(input, 0, to_generate.min())
+ BracketedSequence(input, 0, to_generate.min(), ranks=ranks)
) # Needed to initialize the model's cache
- s = to_generate.min()
-
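+    # generate span by span, each span delimited by consecutive positions where
+    # some row of mask_ar is 1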
for s, u in zip(indices_1[:-1], indices_1[1:]):
- logits = model(
- BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar))
- ).x
+ logits = model(BracketedSequence(input, s, u - s, ranks=ranks)).x
if deterministic_synthesis:
t_next = logits.argmax(dim=2)
# - struct, quad_generate, quad_noise, quad_loss
self.train_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
]
model,
c_quizzes,
struct,
+ mask_ar,
+ mask_noise,
mask_loss,
- mask_noise=None,
temperature=1.0,
device=None,
):
-        for input, l in zip(
+        for input, mask_ar, l in zip(
c_quizzes.split(self.batch_size),
+ mask_ar.split(self.batch_size),
seq_logprobas.split(self.batch_size),
):
input = input.to(device)
quiz_mask_loss = self.make_quiz_mask(
input, struct=struct, mask=mask_loss
)
- output = model(mygpt.BracketedSequence(input)).x / temperature
+            output = (
+                model(
+                    mygpt.BracketedSequence(
+                        input, ranks=mygpt.mask_ar_to_ranks(mask_ar)
+                    )
+                ).x
+                / temperature
+            )
l[...] = (
-F.cross_entropy(output.transpose(1, 2), input, reduction="none")
* quiz_mask_loss