######################################################################
-def mask_ar_to_ranks(mask_ar):
- a = (mask_ar < 2).long()
- a = a.cumsum(dim=1) - a
- b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
- a[:, 1:] += b
- return a
-
-
def run_tests(model, quiz_machine, local_device=main_device):
with torch.autograd.no_grad():
model.to(local_device).eval()
targets = input
output = model(
- mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+ mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
targets = input
output = model(
- mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+ mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
).x
loss_per_token = F.cross_entropy(
# reset when the input bracket starts at t=0
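+
+# Turns a mask_ar tensor (B x T) into per-token generation ranks for
+# BracketedSequence: the rank of token i is the number of earlier
+# tokens with mask_ar < 2, plus the number of earlier positions where
+# a run of mask_ar == 2 ends. With values in {0, 1, 2}, tokens with
+# mask_ar < 2 each get their own rank (one-at-a-time autoregression),
+# while a run of mask_ar == 2 shares a single rank (presumably decoded
+# in parallel, after all tokens of lower rank).
+#
+# e.g. mask_ar [[0, 0, 2, 2, 2, 1, 1]] -> ranks [[0, 1, 2, 2, 2, 3, 4]]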
+def mask_ar_to_ranks(mask_ar):
+ a = (mask_ar < 2).long()
+ a = a.cumsum(dim=1) - a
+ b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
+ a[:, 1:] += b
+ return a
+
+
class BracketedSequence:
def __init__(self, x, first=None, nb=None, ranks=None):
self.x = x
s = to_generate.min()
for s, u in zip(indices_1[:-1], indices_1[1:]):
- logits = model(BracketedSequence(input, s, u - s)).x
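+ # thread the precomputed generation ranks into the bracketed
+ # forward pass so decoding respects the intended token ordering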
+ logits = model(
+ BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar))
+ ).x
if deterministic_synthesis:
t_next = logits.argmax(dim=2)
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i], _ = self.predict(
- model=model, quizzes=input[i], struct=struct, quad=quad_ar
+ model=model, quizzes=input[i], struct=struct, quad_ar=quad_ar
)
predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :]