formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument("--task", type=str, default="picoclvr")
+parser.add_argument(
+ "--task", type=str, default="picoclvr", help="picoclvr, mnist, maze, snake"
+)
-parser.add_argument("--log_filename", type=str, default="train.log")
+parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--result_dir", type=str, default="results_default")
progress_bar_desc="autoregression",
device=torch.device("cpu"),
):
+ # p = logits.softmax(1)
+ # entropy[:,s]= p.xlogy(p).sum(1) / math.log(2)
batches = zip(input.split(batch_size), ar_mask.split(batch_size))
if progress_bar_desc is not None:
tqdm.tqdm(
image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
torchvision.utils.save_image(
- img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
+ img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
)
log_string(f"wrote {image_name}")
def compute_error(self, model, split="train", nb_to_use=-1):
nb_total, nb_correct = 0, 0
- for input in task.batches(split, nb_to_use):
+ count = torch.zeros(
+ self.width * self.height,
+ self.width * self.height,
+ device=self.device,
+ dtype=torch.int64,
+ )
+ for input in tqdm.tqdm(
+ task.batches(split, nb_to_use),
+ dynamic_ncols=True,
+ desc=f"test-mazes",
+ ):
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
result *= 1 - ar_mask
masked_inplace_autoregression(
- model, self.batch_size, result, ar_mask, device=self.device
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ progress_bar_desc=None,
+ device=self.device,
)
mazes, paths = self.seq2map(result)
- nb_correct += maze.path_correctness(mazes, paths).long().sum()
+ path_correctness = maze.path_correctness(mazes, paths)
+ nb_correct += path_correctness.long().sum()
nb_total += mazes.size(0)
- return nb_total, nb_correct
+ optimal_path_lengths = (
+ (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
+ )
+ predicted_path_lengths = (
+ (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
+ )
+ optimal_path_lengths = optimal_path_lengths[path_correctness]
+ predicted_path_lengths = predicted_path_lengths[path_correctness]
+ count[optimal_path_lengths, predicted_path_lengths] += 1
+
+ if count.max() == 0:
+ count = None
+ else:
+ count = count[
+ : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
+ ]
+
+ return nb_total, nb_correct, count
def produce_results(self, n_epoch, model):
with torch.autograd.no_grad():
t = model.training
model.eval()
- train_nb_total, train_nb_correct = self.compute_error(
+ train_nb_total, train_nb_correct, count = self.compute_error(
model, "train", nb_to_use=1000
)
log_string(
f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = self.compute_error(
+ test_nb_total, test_nb_correct, count = self.compute_error(
model, "test", nb_to_use=1000
)
log_string(
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ if count is not None:
+ proportion_optimal = count.diagonal().sum().float() / count.sum()
+ log_string(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
+ with open(
+ os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
+ ) as f:
+ for i in range(count.size(0)):
+ for j in range(count.size(1)):
+ eol = " " if j < count.size(1) - 1 else "\n"
+ f.write(f"{count[i,j]}{eol}")
+
input = self.test_input[:48]
result = input.clone()
ar_mask = result.new_zeros(result.size())
target_paths=paths,
predicted_paths=predicted_paths,
path_correct=maze.path_correctness(mazes, predicted_paths),
+ path_optimal=maze.path_optimality(paths, predicted_paths),
)
log_string(f"wrote {filename}")