"--task",
type=str,
default="picoclvr",
- help="picoclvr, mnist, maze, snake, stack, expr",
+ help="picoclvr, mnist, maze, snake, stack, expr, world",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--snake_length", type=int, default=200)
##############################
-# Snake options
+# Stack options
parser.add_argument("--stack_nb_steps", type=int, default=100)
"nb_train_samples": 1000000,
"nb_test_samples": 10000,
},
+ "world": {
+ "nb_epochs": 5,
+ "batch_size": 25,
+ "nb_train_samples": 10000,
+ "nb_test_samples": 1000,
+ },
}
if args.task in default_args:
device=device,
)
+elif args.task == "world":
+ task = tasks.World(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ device=device,
+ )
+
else:
raise ValueError(f"Unknown task {args.task}")
)
result *= 1 - ar_mask
- # snake.solver(result,ar_mask)
-
masked_inplace_autoregression(
model,
self.batch_size,
nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
- # nb_total = result.size(0)
- # nb_correct = ((result - input).abs().sum(1) == 0).sum()
-
return nb_total, nb_correct
- # train_nb_total, train_nb_correct = compute_nb_correct(
- # self.train_input, self.train_prior_visits
- # )
-
- # logger(
- # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- # )
-
test_nb_total, test_nb_correct = compute_nb_correct(
self.test_input[:1000], self.test_prior_visits[:1000]
)
##############################################################
+######################################################################
+import world
+
+
+class World(Task):
+    # Task adapter for the synthetic "world" dataset: raw frames are produced
+    # by world.create_data_and_processors and arrive here already encoded as
+    # integer token sequences (presumably via the frame2seq/seq2frame codec
+    # returned alongside the data -- confirm in world.py).
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        device=torch.device("cpu"),
+    ):
+        self.batch_size = batch_size
+        self.device = device
+
+        # frame2seq / seq2frame are the encode/decode callables of the learned
+        # quantizer; they are kept so results can be rendered back to frames.
+        (
+            self.train_input,
+            self.train_actions,
+            self.test_input,
+            self.test_actions,
+            self.frame2seq,
+            self.seq2frame,
+        ) = world.create_data_and_processors(
+            nb_train_samples,
+            nb_test_samples,
+            mode="first_last",
+            nb_steps=30,
+            nb_epochs=2,
+        )
+
+        # Vocabulary size = highest token id over both splits, plus one.
+        # NOTE(review): tensor.max() yields a 0-d tensor, so nb_codes is a
+        # tensor rather than a plain int -- verify downstream users accept that.
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        # Yield mini-batches of token sequences for the requested split.
+        # nb_to_use > 0 truncates the split to its first nb_to_use samples.
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
+        # The final batch may be smaller than batch_size (tensor.split keeps
+        # the remainder).
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        # Number of distinct token codes (see nb_codes note in __init__).
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        # Intentional no-op stub: no evaluation/sample generation is
+        # implemented for this task yet.
+        pass
+
+
+
######################################################################
def generate_episodes(nb, steps):
+    # Generate `nb` episodes and concatenate their frames and actions along
+    # dim 0.  `steps` is forwarded to generate_episode; presumably it is a
+    # per-timestep boolean mask selecting which frames to keep -- confirm
+    # against generate_episode.
-    all_frames = []
+    all_frames, all_actions = [], []
    for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
        frames, actions = generate_episode(steps)
+        # NOTE(review): asymmetry here -- `frames` appears to be a list of
+        # tensors (extended element-wise) while `actions` is a single tensor
+        # per episode (appended whole before the final cat); verify shapes.
        all_frames += frames
-    return torch.cat(all_frames, 0).contiguous()
+        all_actions += [actions]
+    return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
-def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10):
- steps = [True] + [False] * 30 + [True]
- train_input = generate_episodes(nb_train_samples, steps)
- test_input = generate_episodes(nb_test_samples, steps)
+def create_data_and_processors(
+ nb_train_samples, nb_test_samples, mode, nb_steps, nb_epochs=10
+):
+ assert mode in ["first_last"]
+
+ if mode == "first_last":
+ steps = [True] + [False] * (nb_steps + 1) + [True]
+
+ train_input, train_actions = generate_episodes(nb_train_samples, steps)
+ test_input, test_actions = generate_episodes(nb_test_samples, steps)
encoder, quantizer, decoder = train_encoder(
train_input, test_input, nb_epochs=nb_epochs
return torch.cat(frames, dim=0)
- return train_input, test_input, frame2seq, seq2frame
+ return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame
######################################################################
if __name__ == "__main__":
- train_input, test_input, frame2seq, seq2frame = create_data_and_processors(
+ (
+ train_input,
+ train_actions,
+ test_input,
+ test_actions,
+ frame2seq,
+ seq2frame,
+ ) = create_data_and_processors(
# 10000, 1000,
- 100, 100, nb_epochs=2,
+ 100,
+ 100,
+ nb_epochs=2,
+ mode="first_last",
+ nb_steps=20,
)
input = test_input[:64]