parser.add_argument("--nb_epochs", type=int, default=25)
-parser.add_argument("--batch_size", type=int, default=100)
+parser.add_argument("--nb_train_samples", type=int, default=200000)
-parser.add_argument("--data_size", type=int, default=-1)
+parser.add_argument("--nb_test_samples", type=int, default=50000)
+
+parser.add_argument("--batch_size", type=int, default=25)
parser.add_argument("--optim", type=str, default="adam")
s = s.reshape(s.size(0), -1, self.height, self.width)
return (s[:, k] for k in range(s.size(1)))
- def __init__(self, batch_size, height, width, nb_walls, device=torch.device("cpu")):
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ nb_walls,
+ device=torch.device("cpu"),
+ ):
self.batch_size = batch_size
self.height = height
self.width = width
self.device = device
- nb = args.data_size if args.data_size > 0 else 250000
-
mazes_train, paths_train = maze.create_maze_data(
- (4 * nb) // 5,
+ nb_train_samples,
height=height,
width=width,
nb_walls=nb_walls,
self.nb_codes = self.train_input.max() + 1
mazes_test, paths_test = maze.create_maze_data(
- nb // 5,
+ nb_test_samples,
height=height,
width=width,
nb_walls=nb_walls,
mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
self.test_input = self.map2seq(mazes_test, paths_test)
- def batches(self, split="train"):
+ def batches(self, split="train", nb_to_use=-1):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
for batch in tqdm.tqdm(
input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
):
def vocabulary_size(self):
return self.nb_codes
- def compute_error(self, model, split="train"):
+ def compute_error(self, model, split="train", nb_to_use=-1):
nb_total, nb_correct = 0, 0
- for input in task.batches(split):
+        for input in self.batches(split, nb_to_use):
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
return nb_total, nb_correct
def produce_results(self, n_epoch, model):
- train_nb_total, train_nb_correct = self.compute_error(model, "train")
- log_string(
- f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
-
- test_nb_total, test_nb_correct = self.compute_error(model, "test")
- log_string(
- f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
+        with torch.no_grad():
+ t = model.training
+ model.eval()
+
+ train_nb_total, train_nb_correct = self.compute_error(
+ model, "train", nb_to_use=1000
+ )
+ log_string(
+ f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+ )
+
+ test_nb_total, test_nb_correct = self.compute_error(
+ model, "test", nb_to_use=1000
+ )
+ log_string(
+ f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
+
+ input = self.test_input[:32]
+ result = input.clone()
+ ar_mask = result.new_zeros(result.size())
- input = self.test_input[:32]
- result = input.clone()
- ar_mask = result.new_zeros(result.size())
+ ar_mask[:, self.height * self.width :] = 1
+ masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
- ar_mask[:, self.height * self.width :] = 1
- masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+ mazes, paths = self.seq2map(input)
+ _, predicted_paths = self.seq2map(result)
+ maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
- mazes, paths = self.seq2map(input)
- _, predicted_paths = self.seq2map(result)
- maze.save_image(f"result_{n_epoch:04d}.png", mazes, paths, predicted_paths)
+ model.train(t)
######################################################################
task = TaskMaze(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
height=args.world_height,
width=args.world_width,