parser.add_argument("--expr_input_file", type=str, default=None)
-##############################
-# World options
-
-parser.add_argument("--world_vqae_nb_epochs", type=int, default=25)
-
######################################################################
args = parser.parse_args()
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
-
"mnist": {
"model": "37M",
"batch_size": 10,
"nb_train_samples": 60000,
"nb_test_samples": 10000,
},
- "world": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 25000,
- "nb_test_samples": 1000,
- },
}
if args.task in default_task_args:
device=device,
)
-elif args.task == "world":
- task = tasks.World(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
- vqae_nb_epochs=args.world_vqae_nb_epochs,
- logger=log_string,
- device=device,
- )
-
else:
raise ValueError(f"Unknown task {args.task}")
# print(f'@2 {i=} {j=}')
+def seq2str(seq):
+ return "".join(["NESW123456789"[i] for i in seq])
+
+
######################################################################
if __name__ == "__main__":
- import cairo, numpy, math
-
- color_name2rgb = {
- "red": [255, 0, 0],
- "green": [0, 128, 0],
- "blue": [0, 0, 255],
- "yellow": [255, 255, 0],
- "orange": [255, 128, 0],
- "maroon": [128, 0, 0],
- "dark_red": [139, 0, 0],
- "brown": [165, 42, 42],
- "firebrick": [178, 34, 34],
- "crimson": [220, 20, 60],
- "tomato": [255, 99, 71],
- "coral": [255, 127, 80],
- "indian_red": [205, 92, 92],
- "light_coral": [240, 128, 128],
- "dark_salmon": [233, 150, 122],
- "salmon": [250, 128, 114],
- }
-
- sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences(
- 8, 6, 8, 5, 20, 10
+ train_input, train_prior_visits, _, _ = generate_sequences(
+ nb=20,
+ height=9,
+ width=12,
+ nb_colors=5,
+ length=50,
+ prompt_length=100,
)
- delta = 16
- height, width = sequences.size(0) * 16, sequences.size(1) * 16
- pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy()
- surface = cairo.ImageSurface.create_for_data(
- pixel_map, cairo.FORMAT_ARGB32, width, height
- )
- ctx = cairo.Context(surface)
- ctx.set_line_width(1.0)
-
- ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
-
- ctx.fill()
+ print([seq2str(s) for s in train_input])
######################################################################
######################################################################
-
-import world
-
-
-class World(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- vqae_nb_epochs,
- logger=None,
- device=torch.device("cpu"),
- device_storage=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
-
- (
- train_frames,
- train_action_seq,
- test_frames,
- test_action_seq,
- self.frame2seq,
- self.seq2frame,
- ) = world.create_data_and_processors(
- nb_train_samples,
- nb_test_samples,
- mode="first_last",
- nb_steps=30,
- nb_epochs=vqae_nb_epochs,
- logger=logger,
- device=device,
- device_storage=device_storage,
- )
-
- train_frame_seq = self.frame2seq(train_frames).to(device_storage)
- test_frame_seq = self.frame2seq(test_frames).to(device_storage)
-
- nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
- nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
-
- self.len_frame_seq = train_frame_seq.size(1)
- self.len_action_seq = train_action_seq.size(1)
- self.nb_codes = nb_frame_codes + nb_action_codes
-
- train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
-
- train_action_seq += nb_frame_codes
- self.train_input = torch.cat(
- (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
- )
-
- test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
- test_action_seq += nb_frame_codes
- self.test_input = torch.cat(
- (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
- )
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch.to(self.device)
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- k = torch.arange(
- 2 * self.len_frame_seq + self.len_action_seq, device=self.device
- )[None, :]
-
- input = self.test_input[:64].to(self.device)
- result = input.clone()
-
- ar_mask = (
- (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
- )
- result *= 1 - ar_mask
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- seq_start = input[:, : self.len_frame_seq]
- seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
- seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
-
- result = torch.cat(
- (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
- )
- result = result.reshape(-1, result.size(-1))
-
- frames = self.seq2frame(result)
- image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
- torchvision.utils.save_image(
- frames.float() / (world.Box.nb_rgb_levels - 1),
- image_name,
- nrow=12,
- padding=1,
- pad_value=0.0,
- )
- logger(f"wrote {image_name}")
-
-
-######################################################################