From: François Fleuret Date: Mon, 20 Mar 2023 20:52:20 +0000 (+0100) Subject: Update X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=beaver.git;a=commitdiff_plain;h=2e22edce168279392e9f76330c585730d364538e Update --- diff --git a/beaver.py b/beaver.py index b505156..f62c749 100755 --- a/beaver.py +++ b/beaver.py @@ -173,13 +173,21 @@ def compute_perplexity(model, split="train"): ###################################################################### -def oneshot_policy_loss(output, policies, mask): - targets = policies.permute(0, 2, 1) * mask.unsqueeze(-1) - output = output * mask.unsqueeze(-1) - return -(output.log_softmax(-1) * targets).sum() / mask.sum() +def oneshot_policy_loss(mazes, output, policies, height, width): + masks = (mazes == maze.v_empty).unsqueeze(-1) + targets = policies.permute(0, 2, 1) * masks + output = output * masks + return -(output.log_softmax(-1) * targets).sum() / masks.sum() -# loss = (output.softmax(-1) - targets).abs().max(-1).values.mean() +def oneshot_trace_loss(mazes, output, policies, height, width): + masks = mazes == maze.v_empty + targets = maze.stationary_densities( + mazes.view(-1, height, width), policies.view(-1, 4, height, width) + ).flatten(-2) + targets = targets * masks + output = output.squeeze(-1) * masks + return (output - targets).abs().sum() / masks.sum() def oneshot(gpt, task): @@ -198,6 +206,7 @@ def oneshot(gpt, task): compute_loss = oneshot_policy_loss elif args.oneshot_output == "trace": dim_out = 1 + compute_loss = oneshot_trace_loss else: raise ValueError(f"{args.oneshot_output=}") @@ -206,7 +215,7 @@ def oneshot(gpt, task): nn.ReLU(), nn.Linear(args.dim_model, args.dim_model), nn.ReLU(), - nn.Linear(args.dim_model, 4), + nn.Linear(args.dim_model, dim_out), ).to(device) for n_epoch in range(args.nb_epochs): @@ -214,54 +223,66 @@ def oneshot(gpt, task): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) acc_train_loss, nb_train_samples = 0, 0 - for input, policies in task.policy_batches(split="train"): + for mazes, policies in task.policy_batches(split="train"): #### - # print(f'{input.size()=} {policies.size()=}') + # print(f'{mazes.size()=} {policies.size()=}') # s = maze.stationary_densities( # exit(0) #### - mask = input == maze.v_empty - output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x + masks = mazes == maze.v_empty + output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x output = model(output_gpt) - loss = compute_loss(output, policies, mask) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_train_loss += loss.item() * mazes.size(0) + nb_train_samples += mazes.size(0) optimizer.zero_grad() loss.backward() optimizer.step() acc_test_loss, nb_test_samples = 0, 0 - for input, policies in task.policy_batches(split="test"): - mask = input == maze.v_empty - output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x + for mazes, policies in task.policy_batches(split="test"): + output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x output = model(output_gpt) - loss = compute_loss(output, policies, mask) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_test_loss += loss.item() * mazes.size(0) + nb_test_samples += mazes.size(0) log_string( f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}" ) # ------------------- - input = task.test_input[:32, : task.height * task.width] - targets = task.test_policies[:32].permute(0, 2, 1) - output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_input).x + mazes = task.test_input[:32, : task.height * task.width] + policies = task.test_policies[:32] + output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x output = model(output_gpt) - scores = ( - (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0 - ).float() + if args.oneshot_output == "policy": + targets = policies.permute(0, 2, 1) + scores = ( + (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0 + ).float() + elif args.oneshot_output == "trace": + targets = maze.stationary_densities( + mazes.view(-1, task.height, task.width), + policies.view(-1, 4, task.height, task.width), + ).flatten(-2) + scores = output.flatten(-2) + else: + raise ValueError(f"{args.oneshot_output=}") + scores = scores.reshape(-1, task.height, task.width) - input = input.reshape(-1, task.height, task.width) + mazes = mazes.reshape(-1, task.height, task.width) + targets = targets.reshape(-1, task.height, task.width) maze.save_image( os.path.join( args.result_dir, f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png", ), - mazes=input, + mazes=mazes, score_paths=scores, + score_truth=targets, ) # ------------------- diff --git a/maze.py b/maze.py index d09e860..81afcd9 100755 --- a/maze.py +++ b/maze.py @@ -98,10 +98,10 @@ def compute_policy(walls, goal_i, goal_j): distance = distance + walls.numel() * walls value = distance.new_full((4,) + distance.size(), walls.numel()) - value[0, :, 1:] = distance[:, :-1] - value[1, :, :-1] = distance[:, 1:] - value[2, 1:, :] = distance[:-1, :] - value[3, :-1, :] = distance[1:, :] + value[0, :, 1:] = distance[:, :-1] # < + value[1, :, :-1] = distance[:, 1:] # > + value[2, 1:, :] = distance[:-1, :] # ^ + value[3, :-1, :] = distance[1:, :] # v proba = (value.min(dim=0)[0][None] == value).float() proba = proba / proba.sum(dim=0)[None] @@ -111,18 +111,19 @@ def compute_policy(walls, goal_i, goal_j): def stationary_densities(mazes, policies): + policies = policies * (mazes != v_goal)[:, None] start = (mazes == v_start).nonzero(as_tuple=True) - probas = mazes.new_zeros(mazes.size()) + probas = mazes.new_zeros(mazes.size(), dtype=torch.float32) pred_probas = probas.clone() probas[start] = 1.0 while not pred_probas.equal(probas): pred_probas.copy_(probas) probas.zero_() - probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :] - probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :] - probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1] - probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:] + probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :] + probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :] + probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1] + probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:] probas[start] = 1.0 return probas @@ -211,6 +212,7 @@ def save_image( target_paths=None, predicted_paths=None, score_paths=None, + score_truth=None, path_correct=None, ): colors = torch.tensor( @@ -229,6 +231,17 @@ def save_image( colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2) ) + if score_truth is not None: + score_truth = score_truth.cpu() + c_score_truth = score_truth.unsqueeze(1).expand(-1, 3, -1, -1) + c_score_truth = ( + c_score_truth * colors[4].reshape(1, 3, 1, 1) + + (1 - c_score_truth) * colors[0].reshape(1, 3, 1, 1) + ).long() + c_mazes = (mazes.unsqueeze(1) != v_empty) * c_mazes + ( + mazes.unsqueeze(1) == v_empty + ) * c_score_truth + imgs = c_mazes.unsqueeze(1) if target_paths is not None: diff --git a/tensorstack.py b/tensorstack.py index 584c12d..074588e 100755 --- a/tensorstack.py +++ b/tensorstack.py @@ -11,9 +11,9 @@ import sys def exception_hook(exc_type, exc_value, tb): - r"""Hacks the call stack message to show all the local variables in - case of RuntimeError or ValueError, and prints tensors as shape, - dtype and device. + r"""Hacks the call stack message to show all the local variables + in case of relevant error, and prints tensors as shape, dtype and + device. """ @@ -28,7 +28,7 @@ def exception_hook(exc_type, exc_value, tb): print(f' File "{filename}", line {line_no}, in {name}') print(open(filename, "r").readlines()[line_no - 1]) - if exc_type in {RuntimeError, ValueError}: + if exc_type in {RuntimeError, ValueError, IndexError}: for n, v in tb.tb_frame.f_locals.items(): print(f" {n} -> {v}")