From 92a2935401f6fe21efe19ac3c476521665a242a5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 7 Apr 2023 14:00:09 +0200 Subject: [PATCH] Update --- beaver.py | 54 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/beaver.py b/beaver.py index 7800527..6ed9dd2 100755 --- a/beaver.py +++ b/beaver.py @@ -233,7 +233,30 @@ def oneshot_trace_loss(mazes, output, policies, height, width): return (output - targets).abs().sum() / masks.sum() -def oneshot(gpt, learning_rate_scheduler, task): +def oneshot(model, learning_rate_scheduler, task): + t = model.training + model.eval() + mazes = task.test_input[:32].clone() + mazes[:, task.height * task.width :] = 0 + output = eval_mygpt(model, mazes, prompt_len=task.height * task.width) + output = F.softmax(output, dim=2) + print(f"{output.size()=}") + proba_path = output[:, task.height * task.width :, 4].reshape( + -1, task.height, task.width + ) + mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width) + # targets = targets.reshape(-1, task.height, task.width) + filename = f"oneshot.png" + maze.save_image( + os.path.join(args.result_dir, filename), + mazes=mazes, + score_paths=proba_path, + # score_truth=targets, + ) + log_string(f"wrote {filename}") + + +def oneshot_old(gpt, learning_rate_scheduler, task): t = gpt.training gpt.eval() @@ -598,14 +621,24 @@ def noncausal_prompt_amm_generator(d): q = torch.arange(d)[:, None] k = torch.arange(d)[None, :] s = args.maze_height * args.maze_width - # return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s)) - return q < k + return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s)) + # return q < k + +def noncausal_prompt_oneshot_amm_generator(d): + q = torch.arange(d)[:, None] + k = torch.arange(d)[None, :] + s = args.maze_height * args.maze_width + return k >= s + # return q < k -amm_generator = None -if args.noncausal_prompt: +if args.oneshot: + amm_generator = noncausal_prompt_oneshot_amm_generator +elif args.noncausal_prompt: amm_generator = noncausal_prompt_amm_generator +else: + amm_generator = None model = mygpt.MyGPT( vocabulary_size=vocabulary_size, @@ -683,6 +716,12 @@ else: ###################################################################### +if args.oneshot: + oneshot(model, learning_rate_scheduler, task) + exit(0) + +###################################################################### + token_count = 0 for input in task.batches(split="train"): token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) @@ -771,8 +810,3 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): log_string(f"saved checkpoint {checkpoint_name}") ###################################################################### - -if args.oneshot: - oneshot(model, learning_rate_scheduler, task) - -###################################################################### -- 2.39.5