X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=6e1eaf4137f5a3bc184f4bd8fd529adbd1791745;hb=519b5419b30de82828a41b620e76d993a70423e9;hp=f5b3563ab1cff1ec1c94b1e0668cccd9e41dcbc1;hpb=a113de0d0ba103b6fb1bfdec69b550147a2a262f;p=beaver.git
diff --git a/beaver.py b/beaver.py
index f5b3563..6e1eaf4 100755
--- a/beaver.py
+++ b/beaver.py
@@ -233,7 +233,33 @@ def oneshot_trace_loss(mazes, output, policies, height, width):
     return (output - targets).abs().sum() / masks.sum()
 
 
-def oneshot(gpt, learning_rate_scheduler, task):
+def oneshot(model, learning_rate_scheduler, task):
+    t = model.training
+    model.eval()
+    mazes = task.test_input[:32].clone()
+    mazes[:, task.height * task.width :] = 0
+    output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
+    output = F.softmax(output, dim=2)
+    print(f"{output.size()=}")
+    proba_path = output[:, task.height * task.width :, 4].reshape(
+        -1, task.height, task.width
+    )
+    mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+    paths = task.test_input[:32, task.height * task.width :].reshape(
+        -1, task.height, task.width
+    )
+    filename = f"oneshot.png"
+    maze.save_image(
+        os.path.join(args.result_dir, filename),
+        mazes=mazes,
+        target_paths=paths,
+        score_paths=proba_path,
+        # score_truth=targets,
+    )
+    log_string(f"wrote {filename}")
+
+
+def oneshot_old(gpt, learning_rate_scheduler, task):
     t = gpt.training
     gpt.eval()
 
@@ -265,6 +291,8 @@ def oneshot(gpt, learning_rate_scheduler, task):
     for n_epoch in range(args.nb_epochs):
         learning_rate = learning_rate_scheduler.get_learning_rate()
 
+        log_string(f"learning_rate {n_epoch} {learning_rate}")
+
         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
         acc_train_loss, nb_train_samples = 0, 0
@@ -596,14 +624,24 @@ def noncausal_prompt_amm_generator(d):
     q = torch.arange(d)[:, None]
     k = torch.arange(d)[None, :]
     s = args.maze_height * args.maze_width
-    # return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
-    return q < k
+    return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
+    # return q < k
 
 
-amm_generator = None
+def noncausal_prompt_oneshot_amm_generator(d):
+    q = torch.arange(d)[:, None]
+    k = torch.arange(d)[None, :]
+    s = args.maze_height * args.maze_width
+    return k >= s
+    # return q < k
 
-if args.noncausal_prompt:
+
+if args.oneshot:
+    amm_generator = noncausal_prompt_oneshot_amm_generator
+elif args.noncausal_prompt:
     amm_generator = noncausal_prompt_amm_generator
+else:
+    amm_generator = None
 
 model = mygpt.MyGPT(
     vocabulary_size=vocabulary_size,
@@ -681,6 +719,12 @@ else:
 
 ######################################################################
 
+if args.oneshot:
+    oneshot(model, learning_rate_scheduler, task)
+    exit(0)
+
+######################################################################
+
 token_count = 0
 for input in task.batches(split="train"):
     token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
@@ -711,8 +755,7 @@ learning_rate_scheduler.reset()
 
 for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     learning_rate = learning_rate_scheduler.get_learning_rate()
-
-    log_string(f"learning_rate {learning_rate}")
+    log_string(f"learning_rate {n_epoch} {learning_rate}")
 
     if args.optim == "sgd":
         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
@@ -770,8 +813,3 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    log_string(f"saved checkpoint {checkpoint_name}")
 
 ######################################################################
-
-if args.oneshot:
-    oneshot(model, learning_rate_scheduler, task)
-
-######################################################################
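
Note, not part of the patch: a minimal standalone sketch of the two attention-matrix-mask generators this diff introduces. Two details are assumptions made for the demo only: the prompt length s is passed as an argument instead of being read from args.maze_height * args.maze_width, and the toy sizes d = 6, s = 3 are invented. True entries are the (query, key) pairs that get masked out, consistent with the commented-out plain causal variant q < k in the patched code.

import torch


def noncausal_prompt_amm_generator(d, s):
    # causal masking (q < k forbidden) everywhere except inside the prompt:
    # the first s positions may attend to each other freely
    q = torch.arange(d)[:, None]
    k = torch.arange(d)[None, :]
    return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))


def noncausal_prompt_oneshot_amm_generator(d, s):
    # every position may look at the whole prompt, but at nothing beyond it
    k = torch.arange(d)[None, :]
    return k >= s


if __name__ == "__main__":
    d, s = 6, 3  # toy sequence length and prompt length
    print(noncausal_prompt_amm_generator(d, s).int())
    # the oneshot mask depends only on the key index and broadcasts over queries
    print(noncausal_prompt_oneshot_amm_generator(d, s).int())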