From d0561c903270c6e0f95f3abfa5863ff68ac4e45e Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 25 Mar 2023 21:02:34 +0100
Subject: [PATCH] Oups.

---
 beaver.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/beaver.py b/beaver.py
index f850f69..065cda0 100755
--- a/beaver.py
+++ b/beaver.py
@@ -206,8 +206,8 @@ def compute_perplexity(model, task, fixed_len, split="train"):
             input = input.to(device)
             output = eval_mygpt(model, input, fixed_len=fixed_len)
             if args.noncausal_prompt:
-                t = input.size(1) // 2
-                loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:])
+                d = input.size(1) // 2
+                loss = F.cross_entropy(output[:, d:].transpose(1, 2), input[:, d:])
             else:
                 loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
@@ -523,16 +523,17 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ##############################
 
 
+def noncausal_prompt_amm_generator(d):
+    q = torch.arange(d)[:, None]
+    k = torch.arange(d)[None, :]
+    s = args.maze_height * args.maze_width
+# return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
+    return q < k
+
 amm_generator = None
 
 if args.noncausal_prompt:
-    amm_generator = lambda d: torch.logical_and(
-        torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
-        torch.logical_or(
-            torch.arange(d)[None, None, :, None] >= d // 2,
-            torch.arange(d)[None, None, None, :] >= d // 2,
-        ),
-    )
+    amm_generator = noncausal_prompt_amm_generator
 
 model = mygpt.MyGPT(
     vocabulary_size=vocabulary_size,
@@ -650,11 +651,11 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     for input in task.batches(split="train"):
         input = input.to(device)
         output = eval_mygpt(
-            model, input, mode=args.oneshot_input, fixed_len=task.height * task.width
+            model, input, fixed_len=task.height * task.width
        )
         if args.noncausal_prompt:
-            t = input.size(1) // 2
-            loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:])
+            d = input.size(1) // 2
+            loss = F.cross_entropy(output[:, d:].transpose(1, 2), input[:, d:])
         else:
             loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
-- 
2.39.5
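
Note on the second hunk: the 4-D lambda previously assigned to amm_generator is replaced by the named function noncausal_prompt_amm_generator, which now returns a plain causal 2-D mask (q < k) and keeps a prompt-bidirectional variant commented out. Below is a minimal standalone sketch of the two variants, assuming that a True entry marks an attention weight that mygpt blocks and that s = maze_height * maze_width is the prompt length; the helper names causal_mask and bidirectional_prompt_mask are illustrative only and do not appear in beaver.py.

    # Sketch of the two mask variants in noncausal_prompt_amm_generator.
    # Assumption: True = blocked attention entry (mygpt's apparent convention).
    import torch

    def causal_mask(d):
        # Variant kept by the patch: token q may only attend to tokens k <= q.
        q = torch.arange(d)[:, None]
        k = torch.arange(d)[None, :]
        return q < k

    def bidirectional_prompt_mask(d, s):
        # Commented-out variant: unrestricted attention among the first s
        # prompt tokens, causal masking as soon as query or key is past s.
        q = torch.arange(d)[:, None]
        k = torch.arange(d)[None, :]
        return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))

    print(causal_mask(5).int())
    print(bidirectional_prompt_mask(5, 3).int())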