From deecc7add8b26a0bafb6de49777abd2ed04fff94 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 20 Aug 2024 22:44:49 +0200 Subject: [PATCH] Update. --- main.py | 12 ++---------- mygpt.py | 8 ++++++++ quiz_machine.py | 6 ++++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 8908613..78d01ff 100755 --- a/main.py +++ b/main.py @@ -327,14 +327,6 @@ def optimizer_to(optim, device): ###################################################################### -def mask_ar_to_ranks(mask_ar): - a = (mask_ar < 2).long() - a = a.cumsum(dim=1) - a - b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1) - a[:, 1:] += b - return a - - def run_tests(model, quiz_machine, local_device=main_device): with torch.autograd.no_grad(): model.to(local_device).eval() @@ -366,7 +358,7 @@ def run_tests(model, quiz_machine, local_device=main_device): targets = input output = model( - mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar)) + mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar)) ).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" @@ -427,7 +419,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input output = model( - mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar)) + mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar)) ).x loss_per_token = F.cross_entropy( diff --git a/mygpt.py b/mygpt.py index c69c899..cd5b580 100755 --- a/mygpt.py +++ b/mygpt.py @@ -75,6 +75,14 @@ class RandomBypass(nn.Module): # resetted when the input bracket starts at t=0 +def mask_ar_to_ranks(mask_ar): + a = (mask_ar < 2).long() + a = a.cumsum(dim=1) - a + b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1) + a[:, 1:] += b + return a + + class BracketedSequence: def __init__(self, x, first=None, nb=None, ranks=None): self.x = x diff --git a/quiz_machine.py b/quiz_machine.py index d209a07..8cec909 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -47,7 +47,9 @@ def one_batch_masked_inplace_autoregression( s = to_generate.min() for s, u in zip(indices_1[:-1], indices_1[1:]): - logits = model(BracketedSequence(input, s, u - s)).x + logits = model( + BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar)) + ).x if deterministic_synthesis: t_next = logits.argmax(dim=2) @@ -228,7 +230,7 @@ class QuizMachine: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i], _ = self.predict( - model=model, quizzes=input[i], struct=struct, quad=quad_ar + model=model, quizzes=input[i], struct=struct, quad_ar=quad_ar ) predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :] -- 2.39.5